In [1]:
from collections import Counter, defaultdict
from typing import Tuple, Union
import math
import os
import numpy as np
import pandas as pd
import re
import spacy
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer
import torch
from torch import nn

## ScriptLoader

In [9]:
# Map each single-character line label to an integer id: S=1, N=2, C=3, D=4, T=5, E=6, M=7.
# Any other character falls back to the defaultdict's int() default, i.e. id 0.
label2id = defaultdict(int, {code: idx for idx, code in enumerate("SNCDTEM", start=1)})

class ScriptLoader:
    """Minimal in-memory batch iterator over aligned (scripts, features, labels) collections.

    Every batch is the same contiguous slice along the first axis of all three inputs, so
    example ``k`` of a batch refers to the same row in each. When ``shuffle`` is True, a fresh
    random permutation is applied to all three inputs every time iteration starts (once per epoch).

    NOTE: the loader is its own iterator, so two simultaneous iterations over the same
    instance share the batch cursor.
    """

    def __init__(self, scripts: np.ndarray, features: torch.FloatTensor, labels: torch.IntTensor, \
        batch_size: int, shuffle: bool=False) -> None:
        """
        Args:
            scripts: array whose first axis indexes examples.
            features: tensor whose first axis is aligned with ``scripts``.
            labels: tensor whose first axis is aligned with ``scripts``.
            batch_size: number of examples per batch; the last batch may be smaller.
            shuffle: re-permute the examples at the start of every iteration.

        Raises:
            ValueError: if ``batch_size`` is not positive or the inputs differ in length.
        """
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got {}".format(batch_size))
        if not len(scripts) == len(features) == len(labels):
            raise ValueError("scripts, features and labels must have the same length, got {}, {}, {}"
                             .format(len(scripts), len(features), len(labels)))
        self.scripts = scripts
        self.features = features
        self.labels = labels
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self) -> int:
        """Number of batches in one full pass over the data."""
        return math.ceil(len(self.scripts) / self.batch_size)

    def __iter__(self) -> "ScriptLoader":
        """Start a new pass: optionally reshuffle (keeping the three inputs aligned) and rewind."""
        if self.shuffle:
            index = np.random.permutation(len(self.scripts))
            self.scripts = self.scripts[index]
            self.features = self.features[index]
            self.labels = self.labels[index]
        self.i = 0
        return self

    def __next__(self) -> Tuple[np.ndarray, torch.FloatTensor, torch.IntTensor]:
        """Return the next (scripts, features, labels) batch, or raise StopIteration after the last one."""
        if self.i >= len(self):
            raise StopIteration
        batch_index = slice(self.i * self.batch_size, (self.i + 1) * self.batch_size)
        # Advance the cursor; without this the loader would yield the first batch forever.
        self.i += 1
        return self.scripts[batch_index], self.features[batch_index], self.labels[batch_index]
    
def get_dataloaders(results_folder: str, seqlen: int, train_batch_size: int, eval_batch_size: int, \
    device: Union[str, torch.device]) -> Tuple[ScriptLoader, ScriptLoader, ScriptLoader]:
    """Build train/test/dev ScriptLoaders for windows of ``seqlen`` script lines.

    Reads ``seq_{seqlen}.csv`` from ``results_folder``, which must provide the columns
    ``line_1`` … ``line_{seqlen}``, ``label`` (one label character per line) and ``split``
    (``"train"``, ``"test"`` or ``"dev"``). Per-line features are looked up in ``feats.csv``
    (indexed by the values of the ``line_*`` columns) and cached as ``seq_{seqlen}_feats.pt``.

    NOTE: the cache is keyed only on ``seqlen``; delete the ``.pt`` file whenever ``feats.csv`` or
    ``seq_{seqlen}.csv`` changes, otherwise stale features are silently reused.

    Args:
        results_folder: directory holding the CSV files and the feature cache.
        seqlen: number of lines per example window.
        train_batch_size: batch size of the (shuffled) training loader.
        eval_batch_size: batch size of the test and dev loaders.
        device: device the feature and label tensors are moved to (e.g. ``"cuda:0"``).

    Returns:
        ``(train_loader, test_loader, dev_loader)``; only the training loader shuffles.
    """
    line_columns = ["line_{}".format(i + 1) for i in range(seqlen)]
    df = pd.read_csv(os.path.join(results_folder, "seq_{}.csv".format(seqlen)), index_col=None)
    features_file = os.path.join(results_folder, "seq_{}_feats.pt".format(seqlen))

    if os.path.exists(features_file):
        features = torch.load(features_file).float()
    else:
        feats_df = pd.read_csv(os.path.join(results_folder, "feats.csv"), index_col=0)
        # One vectorised lookup for every (row, position) pair instead of seqlen .loc calls per row;
        # the flattened order is row-major, so reshaping restores (row, position, feature).
        line_ids = df[line_columns].to_numpy().reshape(-1)
        feature_rows = feats_df.loc[line_ids].to_numpy(dtype=np.float32)
        features = torch.from_numpy(feature_rows.reshape(len(df), seqlen, feats_df.shape[1]))
        torch.save(features, features_file)

    scripts = df[line_columns].to_numpy()
    label_ids = df["label"].map(lambda labelseq: [label2id[label] for label in labelseq]).tolist()

    features = features.to(device)
    labels = torch.IntTensor(label_ids).to(device)
    print("scripts : {}, features = {}, labels = {}".format(scripts.shape, features.shape, labels.shape))

    split = df["split"].to_numpy()
    train_index = split == "train"
    test_index = split == "test"
    dev_index = split == "dev"

    train_loader = ScriptLoader(scripts[train_index], features[train_index], labels[train_index], train_batch_size, \
        shuffle=True)
    test_loader = ScriptLoader(scripts[test_index], features[test_index], labels[test_index], eval_batch_size)
    dev_loader = ScriptLoader(scripts[dev_index], features[dev_index], labels[dev_index], eval_batch_size)

    return train_loader, test_loader, dev_loader

In [10]:
# Loaders over 10-line windows; only the training loader shuffles (shuffle=True inside get_dataloaders).
train_loader, test_loader, dev_loader = get_dataloaders(
    results_folder="/workspace/mica-text-robust-script-parser/results",
    seqlen=10,
    train_batch_size=64,
    eval_batch_size=256,
    device="cuda:0",
)

scripts : (60048, 10), features = torch.Size([60048, 10, 38]), labels = torch.Size([60048, 10])


In [11]:
# Take one training batch for inspection; raises StopIteration instead of leaving stale names if the loader is empty.
scripts, features, labels = next(iter(train_loader))

In [12]:
# Label ids of the sampled batch: 1..7 encode "SNCDTEM" via label2id, 0 is the default for any other label.
labels

tensor([[0, 3, 7, 4, 4, 3, 7, 4, 4, 3],
        [2, 2, 2, 2, 2, 1, 2, 2, 2, 2],
        [3, 4, 7, 3, 4, 3, 4, 1, 2, 2],
        [3, 4, 4, 4, 1, 2, 2, 2, 3, 4],
        [4, 3, 4, 4, 4, 3, 4, 0, 4, 3],
        [4, 3, 4, 3, 4, 3, 4, 3, 0, 4],
        [2, 3, 4, 4, 0, 2, 3, 4, 4, 4],
        [2, 3, 4, 0, 1, 2, 2, 2, 3, 4],
        [2, 2, 2, 2, 3, 4, 3, 4, 4, 2],
        [3, 4, 3, 7, 4, 2, 3, 7, 4, 3],
        [4, 4, 4, 1, 2, 2, 2, 3, 4, 4],
        [3, 4, 4, 4, 3, 4, 3, 4, 4, 3],
        [2, 2, 3, 6, 4, 4, 4, 4, 2, 3],
        [4, 3, 4, 4, 3, 4, 4, 4, 4, 4],
        [4, 4, 4, 1, 2, 2, 7, 3, 6, 4],
        [4, 6, 4, 2, 3, 4, 1, 2, 2, 1],
        [2, 2, 2, 3, 4, 4, 3, 4, 3, 4],
        [3, 4, 4, 4, 1, 2, 2, 3, 4, 4],
        [3, 4, 3, 4, 3, 4, 4, 3, 4, 4],
        [2, 3, 4, 4, 2, 2, 2, 1, 2, 2],
        [2, 3, 4, 4, 2, 2, 2, 2, 2, 3],
        [4, 2, 2, 2, 2, 2, 2, 2, 0, 3],
        [2, 3, 6, 4, 3, 4, 4, 6, 4, 3],
        [2, 2, 2, 2, 2, 1, 2, 2, 3, 4],
        [2, 2, 3, 4, 4, 3, 4, 1, 2, 2],


In [14]:
# Frequency of each label id within this single batch.
Counter(labels.reshape(-1).tolist())

Counter({0: 17, 3: 130, 7: 17, 4: 268, 2: 168, 1: 19, 6: 20, 5: 1})

In [18]:
# Inverse-frequency class weights from this batch. `minlength` keeps one slot per label id even when the
# highest ids are absent, and absent classes get weight 0 instead of inf (1/0 raises a RuntimeWarning).
N_LABELS = 8  # ids 1..7 for "SNCDTEM" plus the default id 0
class_counts = np.bincount(labels.flatten().tolist(), minlength=N_LABELS)
r = np.divide(1.0, class_counts, out=np.zeros(class_counts.shape, dtype=float), where=class_counts > 0)

In [19]:
# Sanity check: `r` is derived from np.bincount, so it is a NumPy array rather than a torch tensor.
type(r)

numpy.ndarray

In [20]:
# Normalise the inverse frequencies so the weights sum to 1.
r / np.sum(r)

array([0.04752822, 0.04252525, 0.0048094 , 0.00621523, 0.00301485,
       0.80797982, 0.04039899, 0.04752822])

## ScriptParser

In [21]:
class ScriptParser(nn.Module):
    """Per-line tagger: sentence embedding ⊕ extra line features → LSTM → label logits per line.

    NOTE(review): lines are embedded through ``SentenceTransformer.encode``, which presumably runs
    without gradient tracking, so the encoder is effectively frozen — confirm this is intended.
    """

    def __init__(self, n_features: int, n_labels: int, hidden_size: int = 256,
                 encoder_name: str = "all-mpnet-base-v2") -> None:
        """
        Args:
            n_features: number of extra per-line features concatenated to each sentence embedding.
            n_labels: number of label ids the classifier predicts.
            hidden_size: LSTM hidden size (previously hard-coded to 256).
            encoder_name: SentenceTransformer model name used to embed each line.
        """
        super().__init__()
        self.encoder = SentenceTransformer(encoder_name)
        self.feature_size = self.encoder.get_sentence_embedding_dimension() + n_features
        self.hidden_size = hidden_size
        self.n_labels = n_labels
        self.lstm = nn.LSTM(self.feature_size, self.hidden_size, batch_first=True)
        self.classifier = nn.Linear(self.hidden_size, self.n_labels)

    def forward(self, scripts: np.ndarray, features: torch.FloatTensor, labels: torch.IntTensor = None) -> \
        Union[torch.LongTensor, Tuple[torch.Tensor, torch.LongTensor]]:
        """Predict a label id for every line, and the cross-entropy loss when ``labels`` is given.

        Args:
            scripts: 2-D array of line texts, shape (batch, seqlen).
            features: extra features, concatenated with the embeddings along dim 2; presumably
                (batch, seqlen, n_features) — TODO confirm against callers.
            labels: optional label ids, one per line; any integer dtype is accepted.

        Returns:
            ``pred`` of shape (batch, seqlen) if ``labels`` is None, else ``(loss, pred)``.
        """
        batch_size, seqlen = scripts.shape
        device = next(self.parameters()).device
        embeddings = self.encoder.encode(scripts.reshape(-1).tolist(), convert_to_tensor=True, device=device)
        embeddings = embeddings.reshape(batch_size, seqlen, -1)
        lstm_input = torch.cat([embeddings, features], dim=2)
        output, _ = self.lstm(lstm_input)
        logits = self.classifier(output)
        pred = logits.argmax(dim=2)
        if labels is None:
            return pred
        # cross_entropy needs int64 class-index targets; the data loaders produce int32 (IntTensor) labels.
        loss = nn.functional.cross_entropy(logits.reshape(-1, self.n_labels), labels.reshape(-1).long())
        return loss, pred

In [22]:
# 38 extra features per line (the last dim of `features` printed by get_dataloaders) and 8 label ids (0 + "SNCDTEM").
scriptparser = ScriptParser(n_features=38, n_labels=8)

In [32]:
# Parameters outside the sentence encoder (LSTM + classifier), materialised as a list so it can be iterated
# more than once; a wrapped generator is exhausted after its first use and then silently yields nothing.
iterable = [parameter for name, parameter in scriptparser.named_parameters() if not name.startswith("encoder")]

## Cross-validation

In [2]:
# Per-line records (movie, line_no, text, label, error) used for the cross-validation analysis.
data_df = pd.read_csv(os.path.join("/workspace/mica-text-robust-script-parser/results", "data.csv"))

In [3]:
# Peek at the first rows to check the columns loaded from data.csv.
data_df.head()

Unnamed: 0,movie,line_no,text,label,error
0,44_inch_chest,1,out to wear ... whatever combination - it,D,NONE
1,44_inch_chest,2,works! - You look superb! ... And your,D,NONE
2,44_inch_chest,3,underw ear - immac ulate ! 100 % cot ton!,D,NONE
3,44_inch_chest,4,Dazzlin'!... Not like my pinky grey-y,D,NONE
4,44_inch_chest,5,"things! Nah, you've just got it - good at",D,NONE


In [10]:
# Distribution of the number of lines per (movie, error) group.
data_df.groupby(["movie", "error"]).agg(line_no=("line_no", "size")).describe()

Unnamed: 0,line_no
count,351.0
mean,256.923077
std,16.47118
min,233.0
25%,250.0
50%,251.0
75%,251.0
max,320.0


In [11]:
# Scratch cell: allocates a 10x50 zero tensor on cuda:0 (fails when no GPU is available).
torch.zeros(10, 50, device=torch.device("cuda:0"))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,