In [25]:
import re
import math
import ast
import random
import builtins
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from collections import Counter
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error,r2_score,f1_score
from functools import partial

In [6]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import math
import ast
from torch.utils.data import DataLoader, TensorDataset
from collections import OrderedDict

# Configuration for the token-classification transformer and its vocabulary.
VOCAB_SIZE = 10000
MAX_LEN = 641
EMBEDDING_DIM = 256
NUM_HEADS = 8
FF_DIM = 256
NUM_TRANSFORMER_BLOCKS = 10
DROPOUT_RATE = 0.1
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'

# Read the evaluation CSV. The `tokens` and `labels` cells are Python-literal
# lists stored as text, so each one is parsed back into a real list.
df = pd.read_csv("classification.csv")
parsed_rows = [
    (ast.literal_eval(raw_tokens), ast.literal_eval(raw_labels))
    for raw_tokens, raw_labels in zip(df['tokens'], df['labels'])
]
essay = {
    'tokens': [tokens for tokens, _ in parsed_rows],
    'labels': [labels for _, labels in parsed_rows],
}
df = pd.DataFrame(essay)

# Label <-> id tables, derived from the labels present in this file
# (sorted so the numbering is deterministic; it has to agree with training).
all_labels_list = [label for labels in df['labels'] for label in labels]
label_list = sorted(set(all_labels_list))
label2id = {label: label_id for label_id, label in enumerate(label_list)}
id2label = dict(enumerate(label_list))
# Padded label positions are filled with the id of 'O' (or 0 if 'O' is absent).
PAD_LABEL_ID = label2id.get('O', 0)

# Token vocabulary: the most frequent lower-cased tokens; ids 0 and 1 are
# reserved for the padding and unknown tokens.
all_tokens = [token.lower() for tokens in df['tokens'] for token in tokens]
word_counts = pd.Series(all_tokens).value_counts()
vocab = word_counts.head(VOCAB_SIZE - 2).index.tolist()
word2id = {word: word_id for word_id, word in enumerate(vocab, start=2)}
word2id[PAD_TOKEN] = 0
word2id[UNK_TOKEN] = 1
id2word = {word_id: word for word, word_id in word2id.items()}
ACTUAL_VOCAB_SIZE = len(word2id)
PAD_TOKEN_ID = word2id[PAD_TOKEN]


def _fit_to_max_len(seq, fill_value):
    """Clip `seq` to MAX_LEN items and right-pad with `fill_value` up to MAX_LEN."""
    clipped = seq[:MAX_LEN]
    return clipped + [fill_value] * (MAX_LEN - len(clipped))


# Encode tokens/labels as ids and bring every sequence to exactly MAX_LEN.
unk_id = word2id[UNK_TOKEN]
X = [[word2id.get(token.lower(), unk_id) for token in seq] for seq in df['tokens']]
y = [[label2id[label] for label in seq] for seq in df['labels']]
X_padded = np.array([_fit_to_max_len(seq, PAD_TOKEN_ID) for seq in X])
y_padded = np.array([_fit_to_max_len(seq, PAD_LABEL_ID) for seq in y])

# Inference runs on CPU.
device = torch.device("cpu")

# Model definition (must match training)
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a batch-first tensor.

    The table is precomputed once for `max_len` positions and stored as the
    buffer `pe` with shape (1, max_len, d_model); even columns hold sines and
    odd columns hold cosines.

    Args:
        d_model: size of the last (feature) dimension of the inputs.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer cosine column than sine
        # columns, so trim div_term to match; for even d_model this slice is
        # the whole vector and the table is unchanged.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return `x` plus the encodings for its first `x.size(1)` positions.

        `x` is indexed as (batch, seq_len, d_model); seq_len is expected not
        to exceed the `max_len` the table was built with.
        """
        return x + self.pe[:, :x.size(1)]

class TransformerEncoderBlock(nn.Module):
    """Encoder block that normalises *before* each sub-layer (pre-LayerNorm).

    Two residual sub-layers are applied in sequence: multi-head self-attention
    and a two-layer ReLU feed-forward network. Attribute names are part of the
    checkpoint layout and are therefore kept as-is.

    Args:
        embed_dim: feature size of the inputs and outputs.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        dropout_rate: dropout used inside attention and after each sub-layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate):
        super().__init__()
        self.att = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=dropout_rate,
            batch_first=True,
        )
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
        )
        self.layernorm1 = nn.LayerNorm(embed_dim)
        self.layernorm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def _attend(self, normed, padding_mask):
        """Self-attention over `normed`; the attention weights are discarded."""
        attended, _ = self.att(normed, normed, normed, key_padding_mask=padding_mask)
        return attended

    def forward(self, x, padding_mask=None):
        """Apply both residual sub-layers.

        Args:
            x: batch-first input (the attention layer is built with
                `batch_first=True`).
            padding_mask: optional mask forwarded as `key_padding_mask`.
        """
        after_attention = x + self.dropout1(self._attend(self.layernorm1(x), padding_mask))
        feed_forward_out = self.ffn(self.layernorm2(after_attention))
        return after_attention + self.dropout2(feed_forward_out)

class TokenClassifierTransformer(nn.Module):
    """Per-token classifier: embedding -> positional encoding -> encoder stack -> linear head.

    Args:
        num_transformer_blocks: how many `TransformerEncoderBlock`s to stack.
        embed_dim: embedding / model width.
        num_heads: attention heads per block.
        ff_dim: feed-forward hidden width per block.
        vocab_size: number of rows in the embedding table.
        num_labels: size of the per-token output.
        max_len: number of positions precomputed by the positional encoder.
        dropout_rate: dropout after the positional encoding and inside blocks.
        padding_idx: token id treated as padding (embedding padding row and
            attention mask).
    """

    def __init__(self, num_transformer_blocks, embed_dim, num_heads, ff_dim, vocab_size, num_labels, max_len, dropout_rate, padding_idx):
        super().__init__()
        # Attribute names define the state_dict keys; keep them unchanged.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.pos_encoder = PositionalEncoding(embed_dim, max_len)
        self.dropout = nn.Dropout(dropout_rate)
        blocks = []
        for _ in range(num_transformer_blocks):
            blocks.append(TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout_rate))
        self.encoder_blocks = nn.ModuleList(blocks)
        self.classifier = nn.Linear(embed_dim, num_labels)
        self.padding_idx = padding_idx

    def forward(self, input_ids):
        """Return per-token logits for a batch of token ids.

        Positions whose id equals `padding_idx` are flagged in the mask handed
        to every encoder block.
        """
        is_padding = input_ids == self.padding_idx
        # Embeddings are scaled by sqrt(embedding width) before positions are added.
        hidden = self.embedding(input_ids) * math.sqrt(self.embedding.embedding_dim)
        hidden = self.dropout(self.pos_encoder(hidden))
        for encoder_block in self.encoder_blocks:
            hidden = encoder_block(hidden, is_padding)
        return self.classifier(hidden)

# Load model
# Build the architecture on the target device first, then restore the weights.
model = TokenClassifierTransformer(
    num_transformer_blocks=NUM_TRANSFORMER_BLOCKS,
    embed_dim=EMBEDDING_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    vocab_size=ACTUAL_VOCAB_SIZE,
    num_labels=len(label2id),
    max_len=MAX_LEN,
    dropout_rate=DROPOUT_RATE,
    padding_idx=PAD_TOKEN_ID
).to(device)

# SECURITY NOTE: torch.load relies on pickle, so only open checkpoint files
# from a trusted source. A weights-only load is not forced here because the
# nn.Module branch below deliberately supports whole-model pickles.
checkpoint = torch.load("classification.pth", map_location=device)

if isinstance(checkpoint, dict):  # also covers OrderedDict, a dict subclass
    # The weights may be nested under a wrapper key or be the mapping itself.
    if 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
elif isinstance(checkpoint, nn.Module):
    print("Warning: Loaded entire model object, extracting state_dict.")
    model.load_state_dict(checkpoint.state_dict())
else:
    raise TypeError(f"Loaded checkpoint is of unexpected type: {type(checkpoint)}")

# Evaluation mode: disables dropout for inference.
model.eval()


# Run the classifier on a single essay and compare against the gold labels.
idx = 4
input_ids = torch.LongTensor(X_padded[idx:idx+1]).to(device)
gt_labels = y_padded[idx]

with torch.no_grad():
    logits = model(input_ids)
pred_ids = logits[0].argmax(dim=-1).cpu().numpy()
input_ids_np = input_ids[0].cpu().numpy()

# Collect (token, prediction, gold) triples up to the first padding token.
aligned = []
for triple in zip(input_ids_np, pred_ids, gt_labels):
    if triple[0] == PAD_TOKEN_ID:
        break
    aligned.append(triple)

# Map ids back to strings; unknown ids fall back to placeholder names.
original_tokens = [id2word.get(tok, '<UNK>') for tok, _, _ in aligned]
predicted_labels = [id2label.get(pred, 'UNK') for _, pred, _ in aligned]
ground_truth_labels = [id2label.get(gold, 'UNK') for _, _, gold in aligned]

print("Original Tokens:   ", original_tokens)
print("Predicted Labels:  ", predicted_labels)
print("Ground Truth Labels:", ground_truth_labels)


Original Tokens:    ['i', 'know', 'that', 'as', 'teenagers', 'we', 'all', 'have', 'the', 'desire', 'to', 'have', 'cell', 'phones', 'and', 'use', 'them', '.', 'i', 'think', 'that', 'we', 'should', "n't", 'be', 'able', 'to', 'have', 'cell', 'phones', 'in', 'school', '.', 'i', 'believe', 'this', 'because', 'cell', 'phones', 'are', 'only', 'objects', 'we', 'do', "n't", 'need', 'to', 'have', 'we', 'just', 'want', 'to', '.', 'i', 'wo', 'nt', 'lie', 'i', 'love', 'my', 'cell', 'phone', 'and', 'like', 'to', 'talk', 'and', 'text', 'to', 'my', 'friends', 'all', 'the', 'time', 'but', 'i', 'think', 'that', 'we', 'should', 'have', 'a', 'limit', 'to', 'how', 'much', 'use', 'we', 'get', 'out', 'of', 'them', '.', 'if', 'we', 'bring', 'them', 'to', 'school', 'it', 'will', 'only', 'increase', 'the', 'amount', 'of', 'people', 'who', 'just', 'slack', 'off', 'and', 'do', "n't", 'do', 'work', '.', 'my', 'father', 'told', 'me', 'that', 'having', 'a', 'cell', 'phone', 'is', 'a', 'privilege', 'not', 'a', 'given

In [7]:
class BertConfig:
    """Plain container for the hyper-parameters read by the BERT modules below.

    Every constructor argument is stored under the same attribute name. In
    addition `attention_head_size` is derived: `hidden_size // num_attention_heads`
    when the head count is positive and divides `hidden_size` evenly, otherwise
    it falls back to `hidden_size` itself.
    """

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        layer_norm_eps=1e-12,
        pad_token_id=0,
    ):
        # Vocabulary / embedding tables
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.pad_token_id = pad_token_id

        # Encoder shape
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # Regularisation / numerics
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.layer_norm_eps = layer_norm_eps

        # Derived per-head width (see class docstring for the fallback rule).
        splits_evenly = num_attention_heads > 0 and hidden_size % num_attention_heads == 0
        self.attention_head_size = hidden_size // num_attention_heads if splits_evenly else hidden_size

In [8]:
class BertEmbeddings(nn.Module):
    """Sum of word, position and token-type embeddings, then LayerNorm and dropout.

    Reads `vocab_size`, `hidden_size`, `pad_token_id`, `max_position_embeddings`,
    `type_vocab_size`, `layer_norm_eps` and `hidden_dropout_prob` from `config`.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.word_embeddings = nn.Embedding(config.vocab_size, width, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, width)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # Row vector 0..max_position_embeddings-1, sliced per call in forward().
        all_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
        self.register_buffer("position_ids", all_positions)
        # Stored for reference; "absolute" unless the config says otherwise.
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

    def forward(self, input_ids, token_type_ids=None):
        """Embed `input_ids`; `token_type_ids` defaults to all zeros."""
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        positions = self.position_ids[:, :input_ids.size(1)].to(input_ids.device)
        summed = (
            self.word_embeddings(input_ids)
            + self.position_embeddings(positions)
            + self.token_type_embeddings(token_type_ids)
        )
        return self.dropout(self.LayerNorm(summed))

In [9]:
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Reads `num_attention_heads`, `attention_head_size`, `hidden_size` and
    `attention_probs_dropout_prob` from `config`.
    """

    def __init__(self, config):
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.attention_head_size
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into (heads, head_size) and move heads to dim 1."""
        split = x.view(*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return split.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None):
        """Return the per-token context vectors with all heads concatenated.

        `attention_mask`, when given, is *added* to the raw attention scores
        before the softmax.
        """
        q, k, v = (
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key, self.value)
        )
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            scores = scores + attention_mask
        probs = self.dropout(nn.functional.softmax(scores, dim=-1))
        # Back to (batch, seq, heads, head_size), then merge the last two dims.
        context = torch.matmul(probs, v).permute(0, 2, 1, 3).contiguous()
        return context.view(*context.size()[:-2], self.all_head_size)

In [10]:
class BertAttentionOutput(nn.Module):
    """Projects the attention result, applies dropout, adds the residual and normalises.

    Reads `hidden_size`, `layer_norm_eps` and `hidden_dropout_prob` from `config`.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return LayerNorm(dropout(dense(hidden_states)) + input_tensor)."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)

In [11]:
class BertAttention(nn.Module):
    """Self-attention followed by its residual/normalising output projection."""

    def __init__(self, config):
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertAttentionOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        """Attend over `hidden_states`, then merge the result with the input as residual."""
        attended = self.self(hidden_states, attention_mask)
        return self.output(attended, hidden_states)

In [12]:
def get_activation(name):
    """Return a fresh activation module for the given name.

    Args:
        name: "gelu" or "relu".

    Returns:
        A new `nn.GELU` or `nn.ReLU` instance.

    Raises:
        ValueError: if `name` is not a supported activation. Failing here
            avoids handing back `None`, which would only surface later as a
            "'NoneType' object is not callable" error inside a forward pass.
    """
    if name == "gelu":
        return nn.GELU()
    if name == "relu":
        return nn.ReLU()
    raise ValueError(f"Unsupported activation {name!r}; expected 'gelu' or 'relu'")

In [13]:
class BertIntermediate(nn.Module):
    """Expands hidden states to `intermediate_size` and applies the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # Activation module is looked up by the name stored in config.hidden_act.
        self.intermediate_act_fn = get_activation(config.hidden_act)

    def forward(self, hidden_states):
        """Return activation(dense(hidden_states))."""
        expanded = self.dense(hidden_states)
        return self.intermediate_act_fn(expanded)

In [14]:
class BertOutput(nn.Module):
    """Projects from `intermediate_size` back to `hidden_size`, then dropout, residual add and LayerNorm.

    Reads `intermediate_size`, `hidden_size`, `layer_norm_eps` and
    `hidden_dropout_prob` from `config`.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return LayerNorm(dropout(dense(hidden_states)) + input_tensor)."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)

In [15]:
class BertLayer(nn.Module):
    """One encoder layer: attention sub-layer followed by the feed-forward sub-layer."""

    def __init__(self, config):
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None):
        """Run attention, expand through the intermediate layer, then project back.

        The attention output is also the residual input of the final
        projection.
        """
        attended = self.attention(hidden_states, attention_mask)
        return self.output(self.intermediate(attended), attended)

In [16]:
class BertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` BertLayer modules applied in order."""

    def __init__(self, config):
        super().__init__()
        layers = []
        for _ in range(config.num_hidden_layers):
            layers.append(BertLayer(config))
        self.layer = nn.ModuleList(layers)

    def forward(self, hidden_states, attention_mask=None):
        """Feed `hidden_states` through every layer, passing the same mask to each."""
        current = hidden_states
        for bert_layer in self.layer:
            current = bert_layer(current, attention_mask)
        return current

In [17]:
class BertModel(nn.Module):
    """Embeddings plus encoder stack; returns the final hidden states."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)

    def get_input_embeddings(self):
        """Return the word-embedding table."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word-embedding table with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Head pruning is not supported by this implementation."""
        raise NotImplementedError

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode `input_ids`.

        Args:
            input_ids: token ids.
            attention_mask: 1 for positions to attend to, 0 for positions to
                ignore; defaults to all ones.
            token_type_ids: segment ids; defaults to all zeros.

        Returns:
            dict with the single key "last_hidden_state".
        """
        mask = torch.ones_like(input_ids) if attention_mask is None else attention_mask
        segment_ids = torch.zeros_like(input_ids) if token_type_ids is None else token_type_ids

        # Turn the 0/1 mask into an additive bias: 0 where attended,
        # -10000 where ignored, with two singleton dims inserted after batch.
        param_dtype = next(self.parameters()).dtype
        broadcast_mask = mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        additive_mask = (1.0 - broadcast_mask) * -10000.0

        embedded = self.embeddings(input_ids=input_ids, token_type_ids=segment_ids)
        encoded = self.encoder(embedded, attention_mask=additive_mask)
        return {"last_hidden_state": encoded}

In [18]:
class SimpleTokenizer:
    """Minimal lower-casing word tokenizer with a BERT-style calling convention.

    Four special tokens ([PAD], [UNK], [CLS], [SEP]) receive ids 0-3 at
    construction time; `build_vocab` then adds the most frequent corpus words
    until `vocab_limit` entries exist. The special-token ids and `vocab_size`
    are available immediately after construction, so the encoding methods can
    be used (mapping every word to [UNK]) even before a vocabulary is built.

    Args:
        vocab_limit: maximum vocabulary size, special tokens included.
    """

    def __init__(self, vocab_limit=5000):
        self.word_to_id = {}
        self.id_to_word = {}
        self.vocab_limit = vocab_limit
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.cls_token = "[CLS]"
        self.sep_token = "[SEP]"
        self.special_tokens = [self.pad_token, self.unk_token, self.cls_token, self.sep_token]
        for i, token in enumerate(self.special_tokens):
            self.word_to_id[token] = i
            self.id_to_word[i] = token
        # Publish the ids right away; they are fixed by the loop above and the
        # encoding methods read them.
        self._refresh_derived_attributes()
        self.vocab_built = False

    def _refresh_derived_attributes(self):
        """Recompute `vocab_size` and the special-token id attributes from `word_to_id`."""
        self.vocab_size = len(self.word_to_id)
        self.pad_token_id = self.word_to_id[self.pad_token]
        self.unk_token_id = self.word_to_id[self.unk_token]
        self.cls_token_id = self.word_to_id[self.cls_token]
        self.sep_token_id = self.word_to_id[self.sep_token]

    def _basic_tokenize(self, text):
        """Lower-case `text` and split it into word/apostrophe runs and the marks . , ! ? ;"""
        return re.findall(r"[\w']+|[.,!?;]", text.lower())

    def build_vocab(self, corpus_records):
        """Populate the vocabulary from `corpus_records`; a no-op once built.

        Each record is read via `record.get("segments", {})`, which is
        expected to map a discourse type to a list of text segments. Words
        are added in descending frequency order until `vocab_limit` is hit.
        """
        if self.vocab_built:
            return
        print(f"Building vocabulary from {len(corpus_records)} records...")
        word_counts = Counter()
        for record in tqdm(corpus_records, desc="Processing Corpus"):
            segments = record.get("segments", {})
            for dt in segments:
                for text_segment in segments[dt]:
                    word_counts.update(self._basic_tokenize(text_segment))
        current_id = len(self.special_tokens)
        # most_common() sorts by count, descending, keeping first-seen order on ties.
        sorted_words = word_counts.most_common()
        for word, count in tqdm(sorted_words, desc="Adding words to vocab"):
            if word not in self.word_to_id:
                if current_id >= self.vocab_limit:
                    break
                self.word_to_id[word] = current_id
                self.id_to_word[current_id] = word
                current_id += 1
        self._refresh_derived_attributes()
        self.vocab_built = True
        print(f"Vocabulary building complete. Size: {self.vocab_size}")

    def tokenize(self, text):
        """Public alias of the basic tokenizer."""
        return self._basic_tokenize(text)

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to ids; unknown tokens map to the [UNK] id."""
        return [self.word_to_id.get(token, self.unk_token_id) for token in tokens]

    def convert_ids_to_tokens(self, ids):
        """Map ids back to tokens; unknown ids map to the [UNK] token."""
        return [self.id_to_word.get(token_id, self.unk_token) for token_id in ids]

    def _pad(self, ids, max_length):
        """Right-pad `ids` with the [PAD] id up to `max_length` (never truncates)."""
        padding_len = max_length - len(ids)
        return ids + [self.pad_token_id] * padding_len if padding_len > 0 else ids

    def _truncate(self, ids, max_length, num_special_tokens=2):
        """Clip `ids` so that `num_special_tokens` extra entries still fit in `max_length`.

        The default of 2 reserves room for [CLS] and [SEP]. The budget is
        clamped at zero so that a very small `max_length` yields an empty list
        instead of a negative slice that would drop tokens from the tail.
        """
        effective_max_len = max(max_length - num_special_tokens, 0)
        return ids[:effective_max_len]

    def encode_plus(self, text, add_special_tokens=True, max_length=None, padding='max_length', truncation=True, return_tensors=None, return_attention_mask=True):
        """Encode one string.

        Args:
            text: input string.
            add_special_tokens: wrap the ids in [CLS] ... [SEP].
            max_length: target length for truncation/padding; None disables both.
            padding: 'max_length' pads (or, if still too long, clips) to `max_length`.
            truncation: clip the token list before ids are assigned.
            return_tensors: "pt" converts every output value to a torch tensor.
            return_attention_mask: include an "attention_mask" of 1s (real
                entries) and 0s (padding).

        Returns:
            dict with "input_ids" and optionally "attention_mask".
        """
        tokens = self.tokenize(text)
        if truncation and max_length is not None:
            # Only reserve slots for [CLS]/[SEP] when they will actually be added.
            reserved = 2 if add_special_tokens else 0
            tokens = self._truncate(tokens, max_length, reserved)
        ids = self.convert_tokens_to_ids(tokens)
        if add_special_tokens:
            ids = [self.cls_token_id] + ids + [self.sep_token_id]
        attention_mask = [1] * len(ids)
        if padding == 'max_length' and max_length is not None:
            pad_len = max_length - len(ids)
            if pad_len > 0:
                ids = ids + [self.pad_token_id] * pad_len
                attention_mask = attention_mask + [0] * pad_len
            elif pad_len < 0:
                ids = ids[:max_length]
                attention_mask = attention_mask[:max_length]
        output = {"input_ids": ids}
        if return_attention_mask:
            output["attention_mask"] = attention_mask
        if return_tensors == "pt":
            for key in output:
                output[key] = torch.tensor(output[key])
        return output

    def __call__(self, text_batch, **kwargs):
        """Encode a string or a list of strings; keyword arguments go to `encode_plus`.

        A list input yields a dict whose values are per-key lists, or stacked
        tensors when `return_tensors="pt"`. An empty list yields `{}`.

        Raises:
            TypeError: if `text_batch` is neither a string nor a list.
        """
        # builtins.str / builtins.list are referenced explicitly so the type
        # checks keep working even if a notebook cell rebinds the bare names.
        if isinstance(text_batch, builtins.str):
            return self.encode_plus(text_batch, **kwargs)
        elif isinstance(text_batch, builtins.list):
            batch_outputs = [self.encode_plus(text, **kwargs) for text in text_batch]
            if not batch_outputs:
                print("Warning: Empty batch encountered in SimpleTokenizer.__call__")
                return {}
            collated = {}
            keys = batch_outputs[0].keys()
            for key in keys:
                if kwargs.get("return_tensors") == "pt":
                    items_to_stack = [item[key] for item in batch_outputs if isinstance(item.get(key), torch.Tensor)]
                    if len(items_to_stack) == len(batch_outputs):
                        collated[key] = torch.stack(items_to_stack)
                    else:
                        print(f"Warning: Could not stack tensors for key '{key}' due to inconsistent types.")
                        collated[key] = [item.get(key) for item in batch_outputs]
                else:
                    collated[key] = [item[key] for item in batch_outputs]
            return collated
        else:
            raise TypeError(f"Input must be a string or a list of strings,got {type(text_batch)}")


In [26]:
# Constants shared with the collate function and model cells below.
MAX_LEN_COLLATE = 128
discourse_types = ["lead", "position", "claim", "rebuttal", "evidence", "concluding", "counterclaim"]

# Token/label sequences per essay, and per-essay scores indexed by essay_id.
df = pd.read_csv('classification.csv')
df1 = pd.read_csv('score.csv')
df1.set_index('essay_id', inplace=True)
df1.drop(columns=['Unnamed: 0'], inplace=True)

essay = []
mx = 0  # longest run of consecutive identical labels seen so far, in tokens ("O" runs included)
for _, essay_row in df.iterrows():
    dic = {"segments": {dt: [] for dt in discourse_types}, "score": []}
    row = list(df1.loc[essay_row['essay_id']].to_numpy())
    ess = ast.literal_eval(essay_row['tokens'])
    lab = ast.literal_eval(essay_row['labels'])
    assert len(ess) == len(lab)
    j = 0
    while j < len(ess):
        # Gather the run of consecutive tokens that share the label at position j.
        # (The accumulator is deliberately NOT called `str`: rebinding that
        # builtin at notebook level breaks every later str(...) call.)
        segment_text = ""
        cur = lab[j]
        length = 0
        while j < len(ess) and lab[j] == cur:
            if segment_text != "":
                segment_text += " "
            segment_text += ess[j]
            length += 1
            j += 1
        mx = max(mx, length)
        if cur == "O":
            # "O" runs are not stored as segments.
            continue
        dic['segments'][cur].append(segment_text)
    dic['score'] = row
    essay.append(dic)
essay_data_list = essay
print(f"Successfully processed {len(essay_data_list)} essays.")
print(f"Maximum segment token length found: {mx}")
score_columns = df1.columns.tolist()
print(f"Score columns being used: {score_columns}")

Successfully processed 1000 essays.
Maximum segment token length found: 641
Score columns being used: ['lead', 'position', 'claim', 'rebuttal', 'evidence', 'concluding', 'counterclaim']


In [27]:
class EssayDataset(Dataset):
    """Dataset over essay records, each holding "segments" and a "score" sequence."""

    def __init__(self, records, tokenizer):
        # `tokenizer` is accepted but not stored or used by this class.
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Return the record's segments untouched plus its scores as a float tensor."""
        record = self.records[idx]
        score_tensor = torch.tensor(list(map(float, record["score"])), dtype=torch.float)
        return {"segments": record["segments"], "labels": score_tensor}

In [28]:
def collate_fn(batch, tokenizer_instance, max_len):
    """Collate dataset items into per-discourse-type tokenized batches plus stacked labels.

    Args:
        batch: list of items with "segments" (dict: discourse type -> list of
            strings) and "labels" (tensor).
        tokenizer_instance: callable tokenizer exposing a `pad_token` attribute.
        max_len: length every encoded text is padded/truncated to.

    Returns:
        (tokenized, labels): `tokenized` maps each entry of the module-level
        `discourse_types` to the tokenizer output for that type across the
        batch; `labels` is the items' label tensors stacked along dim 0.
    """
    labels = torch.stack([sample["labels"] for sample in batch], dim=0)

    def text_for(sample, dt):
        # All segments of one type are joined into a single string; a sample
        # with none gets the tokenizer's pad-token *string* as its text.
        # NOTE(review): how that placeholder is tokenized depends on the
        # tokenizer - confirm it yields the intended ids.
        segs = sample["segments"].get(dt, [])
        return " ".join(segs) if segs else tokenizer_instance.pad_token

    batched_text = {dt: [text_for(sample, dt) for sample in batch] for dt in discourse_types}
    tokenized = {
        dt: tokenizer_instance(
            batched_text[dt],
            padding='max_length',
            truncation=True,
            max_length=max_len,
            return_tensors="pt",
            add_special_tokens=True,
            return_attention_mask=True,
        )
        for dt in discourse_types
    }
    return tokenized, labels

In [29]:
class MoEInteraction(nn.Module):
    """Top-k mixture-of-experts applied independently to each row of a 2-D input.

    Args:
        d_model: feature size of each row.
        n_experts: number of expert MLPs; also the number of rows `forward`
            requires.
        k: number of experts mixed per row.
    """

    def __init__(self, d_model=768, n_experts=len(discourse_types), k=2):
        super().__init__()
        self.gate = nn.Linear(d_model, n_experts)
        expert_mlps = []
        for _ in range(n_experts):
            expert_mlps.append(
                nn.Sequential(
                    nn.Linear(d_model, d_model * 4),
                    nn.GELU(),
                    nn.Linear(d_model * 4, d_model),
                )
            )
        self.experts = nn.ModuleList(expert_mlps)
        self.k = k
        self.n_experts = n_experts

    def forward(self, E):
        """Return a tensor shaped like `E` with each row replaced by its expert mixture.

        If `E` does not have exactly `n_experts` rows, an all-zero tensor of
        the same shape is returned instead.
        """
        out = torch.zeros_like(E)
        if E.shape[0] != self.n_experts:
            return out
        topk_vals, topk_inds = torch.topk(self.gate(E), self.k, dim=-1)
        for row, (row_vals, row_inds) in enumerate(zip(topk_vals, topk_inds)):
            # Softmax over the selected gate logits only.
            mix_weights = torch.softmax(row_vals, dim=-1)
            weighted = [
                weight.unsqueeze(0) * self.experts[expert_index](E[row])
                for expert_index, weight in zip(row_inds, mix_weights)
                if expert_index < self.n_experts
            ]
            if weighted:
                out[row] = torch.stack(weighted).sum(dim=0)
        return out

In [30]:
class HierarchicalMoEScorer(nn.Module):
    """Scores essays from per-discourse-type encodings mixed by an MoE layer.

    Every discourse type's text is encoded by one shared BertModel; the
    first-position hidden state of each encoding is kept, the per-essay stack
    of those vectors goes through `MoEInteraction`, the result is averaged
    over discourse types, and one linear head per label produces a score.

    Args:
        bert_scratch_config: config object handed to BertModel.
        n_experts: number of MoE experts.
        k: experts mixed per row in the MoE layer.
        num_labels: number of scalar outputs.
    """

    def __init__(self, bert_scratch_config, n_experts=len(discourse_types), k=2, num_labels=len(score_columns)):
        super().__init__()
        self.bert = BertModel(bert_scratch_config)
        hidden = bert_scratch_config.hidden_size
        self.moe = MoEInteraction(d_model=hidden, n_experts=n_experts, k=k)
        self.heads = nn.ModuleList([nn.Linear(hidden, 1) for _ in range(num_labels)])
        self.discourse_types = discourse_types
        self.num_labels = num_labels
        self.config = bert_scratch_config

    def forward(self, tokenized_batch):
        """Return predictions with one row per essay and one column per label.

        `tokenized_batch` maps each discourse type to a dict holding
        "input_ids" and "attention_mask" tensors.
        """
        first_type = self.discourse_types[0]
        batch_size = tokenized_batch[first_type]["input_ids"].size(0)
        device = next(self.parameters()).device

        first_token_vectors = []
        for dt in self.discourse_types:
            encoded = tokenized_batch[dt]
            bert_out = self.bert(
                input_ids=encoded["input_ids"].to(device),
                attention_mask=encoded["attention_mask"].to(device),
            )
            # Keep only the hidden state at sequence position 0.
            first_token_vectors.append(bert_out["last_hidden_state"][:, 0, :])

        # Stack discourse types along dim 1, then mix each essay separately.
        E = torch.stack(first_token_vectors, dim=1)
        H = torch.stack([self.moe(E[b]) for b in range(batch_size)], dim=0)
        pooled_H = H.mean(dim=1)
        per_label = [self.heads[j](pooled_H).squeeze(-1) for j in range(self.num_labels)]
        return torch.stack(per_label, dim=1)

In [31]:
# Hold out 20% of the essays for validation (seeded split).
train_records, val_records = train_test_split(
    essay_data_list,
    test_size=0.2,
    random_state=42,
)
print(f"Data split: {len(train_records)} train,{len(val_records)} validation records.")

# The vocabulary is built from the training records only.
tokenizer = SimpleTokenizer(vocab_limit=12000)
tokenizer.build_vocab(train_records)
print(f"Vocab Size: {tokenizer.vocab_size}")

# Datasets and loaders; the collate function gets the tokenizer and max length bound in.
train_data = EssayDataset(train_records, tokenizer)
val_data = EssayDataset(val_records, tokenizer)
collate_fn_with_args = partial(
    collate_fn,
    tokenizer_instance=tokenizer,
    max_len=MAX_LEN_COLLATE,
)
train_loader = DataLoader(
    train_data,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn_with_args,
    num_workers=0,
)
val_loader = DataLoader(
    val_data,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn_with_args,
    num_workers=0,
)

# BERT sized to the tokenizer's vocabulary; the position table matches the
# collate length, so encoded inputs never exceed it.
bert_config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=768,
    num_hidden_layers=6,
    num_attention_heads=12,
    intermediate_size=3072,
    max_position_embeddings=MAX_LEN_COLLATE,
    pad_token_id=tokenizer.pad_token_id,
)
print(f" BERT Config: vocab={bert_config.vocab_size},hidden={bert_config.hidden_size},layers={bert_config.num_hidden_layers},pad_id={bert_config.pad_token_id}")
model = HierarchicalMoEScorer(
    bert_scratch_config=bert_config,
    n_experts=len(discourse_types),
    k=2,
    num_labels=len(score_columns),
)

Data split: 800 train,200 validation records.
Building vocabulary from 800 records...


Processing Corpus:   0%|          | 0/800 [00:00<?, ?it/s]

Adding words to vocab:   0%|          | 0/10556 [00:00<?, ?it/s]

Vocabulary building complete. Size: 10560
Vocab Size: 10560
 BERT Config: vocab=10560,hidden=768,layers=6,pad_id=0
