In [None]:
from fastai.vision.all import *
import pandas as pd
import numpy as np
import albumentations
import cv2
import spacy
from PIL import Image
import os
from tqdm import tqdm
import random

# Generate Images

In [None]:
# Seed Python's global RNG before anything draws from it: the colour palettes
# further down are sampled with `random.choices`, so a fixed seed keeps the
# token/edge colours the same on every fresh top-to-bottom run.
random.seed(42) # consistent color palette
# Test-set table; the `id` and `excerpt` columns are read later in the notebook.
df = pd.read_csv('../input/commonlitreadabilityprize/test.csv')

In [None]:
# Override node attributes to customise the plot
from spacy.tokens import Token

# Register the per-token `plot` extension that holds the node attributes used
# when drawing. The guard keeps this cell re-runnable: registering an extension
# name that already exists raises an error unless it is forced.
if not Token.has_extension('plot'):
    Token.set_extension('plot', default={})  # Create a token underscore extension

In [None]:
# https://github.com/cyclecycle/visualise-spacy-tree/blob/master/visualise_spacy_tree/visualise_spacy_tree.py

import pydot

# Node attributes applied to every token node unless the token's own `plot`
# attributes already supply a value for the same key.
DEFAULT_NODE_ATTRS = dict(
    color='cyan',
    shape='box',
    style='filled',
)

def node_label(token):
    """Return the text shown on the graph node for `token`.

    A custom label stored under ``token._.plot['label']`` wins. When the token
    has no such attribute or no ``'label'`` key, a default label is built from
    the token's text, index, coarse POS and fine-grained tag.
    """
    try:
        return token._.plot['label']
    except (AttributeError, KeyError):
        # AttributeError: the `_` / `plot` attribute is not available.
        # KeyError: `plot` exists but carries no custom label.
        # Any other exception is a real error and is allowed to propagate
        # instead of being silently replaced by the default label.
        return '{0} [{1}]\n({2} / {3})'.format(
            token.orth_,
            token.i,
            token.pos_,
            token.tag_
        )


def get_edge_label(from_token, to_token):
    """Return the dependency relation of `from_token` as the edge label.

    `to_token` is part of the callback signature but is not consulted.
    """
    return from_token.dep_


def to_pydot(tokens, get_edge_label=get_edge_label):
    """Build an undirected pydot graph of the dependency tree over `tokens`.

    One node is created per token, named by ``token.i`` and styled with the
    token's ``_.plot`` attributes, falling back to `DEFAULT_NODE_ATTRS` for
    keys the token does not set. One edge is added from each token's head to
    the token, except for ROOT tokens and tokens whose head is not among
    `tokens`. The edge colour is looked up in the module-level `dep2color`
    using the value returned by `get_edge_label(token, head)`.

    The tokens' own ``_.plot`` dictionaries are left untouched.
    """
    graph = pydot.Dot(graph_type='graph')

    # Add nodes to graph
    idx2node = {}
    for token in tokens:
        try:
            # Copy so that the defaults, `name` and `label` added below do not
            # leak back into the attributes stored on the token itself.
            plot_attrs = dict(token._.plot)
        except AttributeError:
            plot_attrs = {}
        for attr, val in DEFAULT_NODE_ATTRS.items():
            plot_attrs.setdefault(attr, val)
        plot_attrs['name'] = token.i
        plot_attrs['label'] = node_label(token)
        node = pydot.Node(**plot_attrs)
        idx2node[token.i] = node
        graph.add_node(node)

    # Add edges
    for token in tokens:
        if token.dep_ == 'ROOT':
            continue
        head = token.head
        # `idx2node` holds exactly the indices of `tokens`, so this is the
        # "head is one of the tokens" check without rescanning the sequence.
        # NOTE(review): assumes all tokens belong to the same parsed document.
        if head.i not in idx2node:
            continue
        label = get_edge_label(token, head)
        edge = pydot.Edge(
            idx2node[head.i], idx2node[token.i],
            fontsize=12,
            color=dep2color[label]
        )
        graph.add_edge(edge)
    return graph

def create_png(tokens, prog=None):
    """Render the dependency graph of `tokens` and return the PNG data.

    `prog` is passed through unchanged to the graph's `create_png` call.
    """
    return to_pydot(tokens).create_png(prog=prog)

In [None]:
# Dependency-relation labels that receive an edge colour (keys of `dep2color`).
# Order matters: colours are paired with these labels by position from a seeded
# random draw, so reordering or inserting entries changes the palette.
# NOTE(review): the trailing "" presumably covers tokens without a dependency
# label — confirm.
deps = ["ROOT", "acl", "acomp", "advcl", "advmod", "agent", "amod", "appos", "attr", "aux", "auxpass", 
        "case", "cc", "ccomp", "compound", "conj", "csubj", "csubjpass", "dative", "dep", "det", "dobj", 
        "expl", "intj", "mark", "meta", "neg", "nmod", 'npadvmod', "nsubj", "nsubjpass", "nummod", 
        "oprd", "parataxis", "pcomp", "pobj", "poss", "preconj", "predet", "prep", "prt", "punct", 
        "quantmod", "relcl", "xcomp", ""]

In [None]:
# Fine-grained tag values (`token.tag_`) that receive a node colour (keys of
# `tag2color`). Order matters: colours are paired with these tags by position
# from a seeded random draw, so reordering changes the palette.
tags = ["$", "''", ",", "-LRB-", "-RRB-", ".", ":", "ADD", "AFX", "CC", "CD", "DT", "EX", 
        "FW", "HYPH", "IN", "JJ", "JJR", "JJS", "LS", "MD", "NFP", "NN", "NNP", "NNPS", 
        "NNS", "PDT", "POS", "PRP", "PRP$", "RB", "RBR", "RBS", "RP", "SYM", "TO", "UH", "VB", "VBD", 
        "VBG", "VBN", "VBP", "VBZ", "WDT", "WP", "WP$", "WRB", "XX", "``", "_SP"]

In [None]:
# Pool of colour names from which the edge and node colours are sampled (with
# replacement). The list order feeds the seeded draw, so it must not change.
# NOTE(review): these look like Graphviz X11 colour-scheme names — confirm.
cols = ["aliceblue", "antiquewhite", "antiquewhite1", "antiquewhite2", "antiquewhite3",
"antiquewhite4", "aqua", "aquamarine", "aquamarine1", "aquamarine2",
"aquamarine3", "aquamarine4", "azure", "azure1", "azure2",
"azure3", "azure4", "beige", "bisque", "bisque1",
"bisque2", "bisque3", "bisque4", "black", "blanchedalmond",
"blue", "blue1", "blue2", "blue3", "blue4",
"blueviolet", "brown", "brown1", "brown2", "brown3",
"brown4", "burlywood", "burlywood1", "burlywood2", "burlywood3",
"burlywood4", "cadetblue", "cadetblue1", "cadetblue2", "cadetblue3",
"cadetblue4", "chartreuse", "chartreuse1", "chartreuse2", "chartreuse3",
"chartreuse4", "chocolate", "chocolate1", "chocolate2", "chocolate3",
"chocolate4", "coral", "coral1", "coral2", "coral3",
"coral4", "cornflowerblue", "cornsilk", "cornsilk1", "cornsilk2",
"cornsilk3", "cornsilk4", "crimson", "cyan", "cyan1",
"cyan2", "cyan3", "cyan4", "darkblue", "darkcyan",
"darkgoldenrod", "darkgoldenrod1", "darkgoldenrod2", "darkgoldenrod3", "darkgoldenrod4",
"darkgray", "darkgreen", "darkgrey", "darkkhaki", "darkmagenta",
"darkolivegreen", "darkolivegreen1", "darkolivegreen2", "darkolivegreen3", "darkolivegreen4",
"darkorange", "darkorange1", "darkorange2", "darkorange3", "darkorange4",
"darkorchid", "darkorchid1", "darkorchid2", "darkorchid3", "darkorchid4",
"darkred", "darksalmon", "darkseagreen", "darkseagreen1", "darkseagreen2",
"darkseagreen3", "darkseagreen4", "darkslateblue", "darkslategray", "darkslategray1",
"darkslategray2", "darkslategray3", "darkslategray4", "darkslategrey", "darkturquoise",
"darkviolet", "deeppink", "deeppink1", "deeppink2", "deeppink3",
"deeppink4", "deepskyblue", "deepskyblue1", "deepskyblue2", "deepskyblue3",
"deepskyblue4", "dimgray", "dimgrey", "dodgerblue", "dodgerblue1",
"dodgerblue2", "dodgerblue3", "dodgerblue4", "firebrick", "firebrick1",
"firebrick2", "firebrick3", "firebrick4", "floralwhite", "forestgreen",
"fuchsia", "gainsboro", "ghostwhite", "gold", "gold1",
"gold2", "gold3", "gold4", "goldenrod", "goldenrod1",
"goldenrod2", "goldenrod3", "goldenrod4", "gray"]

In [None]:
# Draw one colour per dependency label (with replacement) from the global
# `random` state and pair labels with colours by position.
colors = random.choices(cols, k=len(deps))
dep2color = dict(zip(deps, colors))

In [None]:
# Draw one colour per tag (with replacement) from the global `random` state
# and pair tags with colours by position.
colors = random.choices(cols, k=len(tags))
tag2color = dict(zip(tags, colors))

In [None]:
# Override node attributes to customise the plot
# https://graphviz.gitlab.io/_pages/doc/info/attrs.html
def customize_plot(doc):
    """Set the node label and colour of every token in `doc`, then return `doc`.

    The label is blank padding of two spaces per character of the token text,
    so the node width follows the word length without showing the word. The
    colour is looked up in `tag2color` by the token's `tag_`.
    """
    for tok in doc:
        node_attrs = tok._.plot
        node_attrs['label'] = "  " * len(tok.orth_)
        node_attrs['color'] = tag2color[tok.tag_]
    return doc

In [None]:
# Load the spaCy pipeline used throughout the notebook; the rendering code reads
# its sentence boundaries (`doc.sents`), `tag_` values and `dep_`/`head` relations.
nlp = spacy.load("en_core_web_sm")

In [None]:
# Root directory for the rendered sentence images; created if it is missing.
out_folder = '/tmp/images'
path = Path(out_folder)
path.mkdir(parents=True, exist_ok=True)

In [None]:
def convert_text_to_image(text, path, identifier, scale=0.1, size=512):
    """Render each sentence of `text` as a dependency-tree PNG on disk.

    Files are written to ``path/identifier/<k>.png`` where ``k`` is the
    zero-based position of the sentence in `text`. The folder is created if
    needed. `scale` and `size` are accepted but not used. Returns None.
    """
    doc_folder = path/identifier
    doc_folder.mkdir(parents=True, exist_ok=True)
    for sent_no, sent in enumerate(nlp(text).sents):
        # Each sentence is parsed again from its own text, so the drawn tree
        # comes from a stand-alone parse of that sentence.
        sent_doc = customize_plot(nlp(str(sent)))
        png_bytes = create_png(sent_doc)
        (doc_folder/f'{sent_no}.png').write_bytes(png_bytes)

In [None]:
# Render the sentence images of every excerpt into its own `path/<id>` folder.
for identifier, text in tqdm(zip(df['id'], df['excerpt']), total=len(df)):
    convert_text_to_image(text, path, identifier)

# Model Inference

In [None]:
# Side length in pixels of the square model input: the preprocessing below
# resizes the longest image side to this value and pads up to IMG_SIZE x IMG_SIZE.
IMG_SIZE = 448

In [None]:
class BagOfImagesModel(Module):
    """Predict one value per bag of images.

    Every image of a bag goes through the shared `encoder`; the per-image
    feature vectors are then pooled over the bag with both max and mean, the
    two pooled vectors are concatenated, and a small BatchNorm/Dropout/Linear
    head maps the result to a single output.

    The head expects 4096 concatenated features, i.e. 2048 features per image
    from the encoder.
    """
    def __init__(self, encoder):
        # Attribute names are part of the checkpoint format (state-dict keys),
        # so they must stay as they are.
        self.encoder = encoder
        self.bn1 = nn.BatchNorm1d(4096)
        self.d1 = nn.Dropout(p=0.25)
        self.l1 = nn.Linear(in_features=4096, out_features=512, bias=False)
        self.bn2 = nn.BatchNorm1d(512)
        self.d2 = nn.Dropout(p=0.5)
        self.l2 = nn.Linear(in_features=512, out_features=1, bias=False)

    def forward(self, imgs):
        """Map `imgs` of shape (batch, bag, channels, height, width) to (batch, 1)."""
        b, n, ch, h, w = imgs.shape
        # Fold the bag axis into the batch axis so the encoder sees plain images.
        unrolled = imgs.reshape(b * n, ch, h, w)
        # Flatten only the axes after the first. A blanket squeeze() would also
        # drop the leading axis whenever b * n == 1.
        # NOTE(review): encoder output is presumably (b*n, 2048, 1, 1) for the
        # pooled ResNet body used in this notebook — confirm.
        ftrs = self.encoder(unrolled).flatten(1)
        ftrs = ftrs.reshape(b, n, -1)
        # Pool across the bag axis explicitly; the results are always (b, F),
        # so no special case is needed for a batch of one.
        ftrs_max = ftrs.max(dim=1)[0]
        ftrs_mean = ftrs.mean(dim=1)
        ftrs_cat = torch.cat([ftrs_max, ftrs_mean], dim=1)
        x = self.bn1(ftrs_cat)
        x = self.d1(x)
        x = self.l1(x)
        x = F.relu(x)
        x = self.bn2(x)
        x = self.d2(x)
        out = self.l2(x)
        return out

In [None]:
# Deterministic preprocessing (every step has p=1): resize so the longest side
# is IMG_SIZE, pad with zeros to an IMG_SIZE x IMG_SIZE square, then normalise
# with the library's default statistics.
aug = albumentations.Compose(
    [
        albumentations.LongestMaxSize(max_size=IMG_SIZE, p=1.0),
        albumentations.PadIfNeeded(
            min_height=IMG_SIZE,
            min_width=IMG_SIZE,
            border_mode=cv2.BORDER_CONSTANT,  # same value as the literal 0
            value=0.,
            p=1.0,
        ),
        albumentations.Normalize(p=1.0),
    ],
    p=1.,
)

In [None]:
# for test, we will use all sentences / images for each example with batch size = 1

class TestImageBagDataset(torch.utils.data.Dataset):
    """Dataset that yields, per row of `df`, all images of that row's folder.

    Item ``i`` is ``(imgs, target)`` where ``imgs`` stacks the preprocessed
    images ``path/<id>/0.png`` … ``path/<id>/<n-1>.png`` (``n`` being the number
    of entries in the folder) and ``target`` is a constant 0.0 placeholder.
    Bags differ in size, so items can only be batched one at a time.
    """
    def __init__(self, df, path, aug):
        self.df = df      # table with an `id` column naming the image folders
        self.path = path  # root directory that contains one folder per id
        self.aug = aug    # callable applied as aug(image=...)['image']

    def __getitem__(self, i):
        # `i` is a position in the dataset, not an index label.
        image_id = self.df['id'].iloc[i]
        target = torch.tensor(0, dtype=torch.float)
        img_folder = self.path/image_id
        # Every entry of the folder is assumed to be one of 0.png … (n-1).png.
        num_imgs = sum(1 for _ in img_folder.iterdir())
        if num_imgs == 0:
            raise FileNotFoundError(f'no images found in {img_folder}')
        img_paths = [img_folder/f'{k}.png' for k in range(num_imgs)]
        imgs = torch.stack([self._open_img(p) for p in img_paths])
        return (imgs, target)

    def __len__(self):
        return len(self.df)

    def _open_img(self, x):
        """Load image file `x`, preprocess it and return a float CHW tensor."""
        img = cv2.imread(str(x), cv2.IMREAD_UNCHANGED)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image: {x}')
        # Keep the first three channels only (drops an alpha channel if present).
        # NOTE(review): cv2 loads colour images in BGR order and no conversion is
        # applied; presumably this matches how the weights were trained — confirm
        # before changing.
        img = img[..., :3]
        img = self.aug(image=img)['image']
        img = torch.tensor(img, dtype=torch.float)
        img = img.permute(2, 0, 1)  # HWC -> CHW
        return img

In [None]:
from matplotlib import pyplot as plt
def visualize(image):
    """Show `image` on a new 10x10-inch figure with the axes hidden."""
    _fig, ax = plt.subplots(figsize=(10, 10))
    ax.axis('off')
    ax.imshow(image)

# Sanity check: display the first sentence image of the second excerpt.
dataset = TestImageBagDataset(df, path, aug)
preview_bag = dataset[1][0]  # stacked images of item 1, each in CHW layout
visualize(preview_bag[0].permute(1, 2, 0))  # CHW -> HWC for plotting

In [None]:
# The dataset and dataloaders do not depend on the fold, so build them once
# instead of on every iteration. bs=1 because bags hold different numbers of
# images and therefore cannot be collated together.
test_ds = TestImageBagDataset(df, path, aug)
test_dls = DataLoaders.from_dsets(test_ds, test_ds, bs=1).cuda() # a hack to get dataloaders, there is probably a better way

# One prediction tensor per saved model (model_0 … model_4).
all_preds = []
for k in range(5):
    # Fresh, randomly initialised network; its weights come from the checkpoint.
    encoder = create_body(resnet50, pretrained=False, cut=-1)
    net = BagOfImagesModel(encoder).cuda()
    learn = Learner(test_dls, net, loss_func=MSELossFlat(), metrics=rmse, model_dir=".").to_fp16()
    learn = learn.load(f'../input/commonlit-cv-models-resnet50/model/model_{k}')
    preds, _ = learn.get_preds()
    all_preds.append(preds)

In [None]:
# Average the per-model predictions element-wise and convert to a NumPy array.
fold_preds = torch.stack(all_preds)
preds = fold_preds.mean(dim=0).cpu().numpy()

In [None]:
sub = pd.read_csv('../input/commonlitreadabilityprize/sample_submission.csv')
# Flatten to one value per row; the averaged predictions may carry a trailing
# singleton axis.
flat_preds = np.asarray(preds).reshape(-1)
if len(flat_preds) != len(sub):
    raise ValueError(
        f'got {len(flat_preds)} predictions for {len(sub)} submission rows'
    )
# Item assignment always writes the column. Attribute assignment would set a
# plain Python attribute instead if the column did not already exist.
# NOTE(review): assumes the sample submission lists ids in the same order as
# test.csv — confirm.
sub['target'] = flat_preds
sub.to_csv('submission.csv', index=False)