#### Модели CNN, SepCNN, LSTM на базе предобученных статических эмбеддингов

In [1]:
import torch
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

In [2]:
from utils import *

In [3]:
from google.colab import drive

In [4]:
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

In [5]:
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader

In [6]:
def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators with ``seed``.

    Covers NumPy's global generator, PyTorch's CPU generator, the current
    CUDA device and all CUDA devices. Python's built-in ``random`` module is
    not touched.
    """
    np.random.seed(seed)
    torch_seeders = (
        torch.random.manual_seed,
        torch.cuda.random.manual_seed,
        torch.cuda.random.manual_seed_all,
    )
    for seeder in torch_seeders:
        seeder(seed)

def get_available_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

In [7]:
def emb(data):
    """Build a trainable embedding layer initialised from a pretrained matrix.

    Args:
        data: 2-D array-like (NumPy array, tensor, nested list) laid out as
            ``(num_embeddings, embedding_dim)``.

    Returns:
        A ``torch.nn.Embedding`` whose weight is a copy of ``data`` (cast to
        the layer's dtype). The weight keeps ``requires_grad=True``, so the
        vectors are fine-tuned together with the rest of the model.

    Raises:
        ValueError: if ``data`` is not two-dimensional.
    """
    weights = torch.as_tensor(data)
    if weights.dim() != 2:
        # Unpacking any other rank into Embedding(...) would either fail
        # obscurely or silently bind extra dims to unrelated arguments.
        raise ValueError(
            f'expected a 2-D embedding matrix, got shape {tuple(weights.shape)}'
        )
    layer = torch.nn.Embedding(*weights.shape)
    # Overwrite the random initialisation in place without recording the copy
    # in autograd.
    with torch.no_grad():
        layer.weight.copy_(weights)
    return layer

# def emb(num_embeddings, embedding_dim):
#     return torch.nn.Embedding(num_embeddings, embedding_dim)

def conv(in_channels, out_channels, kernel_size):
    """Return a convolutional feature extractor over a channels-first sequence.

    The stack is Conv1d -> ReLU -> global max pooling over the length axis ->
    flatten, so an input of shape ``(batch, in_channels, length)`` becomes
    ``(batch, out_channels)``.
    """
    layers = [
        torch.nn.Conv1d(in_channels, out_channels, kernel_size),
        torch.nn.ReLU(),
        torch.nn.AdaptiveMaxPool1d(1),
        torch.nn.Flatten(),
    ]
    return torch.nn.Sequential(*layers)

def sep_conv(in_channels, out_channels, kernel_size):
    """Return a depthwise-separable variant of the ``conv`` feature extractor.

    A per-channel (``groups=in_channels``) convolution of width
    ``kernel_size`` is followed by a width-1 convolution that mixes channels,
    then ReLU, global max pooling and flatten. An input of shape
    ``(batch, in_channels, length)`` becomes ``(batch, out_channels)``.
    """
    depthwise = torch.nn.Conv1d(
        in_channels, in_channels, kernel_size, groups=in_channels
    )
    pointwise = torch.nn.Conv1d(in_channels, out_channels, 1)
    return torch.nn.Sequential(
        depthwise,
        pointwise,
        torch.nn.ReLU(),
        torch.nn.AdaptiveMaxPool1d(1),
        torch.nn.Flatten(),
    )

def lstm(input_size, hidden_size):
    """Return a single-layer, unidirectional, batch-first ``torch.nn.LSTM``."""
    return torch.nn.LSTM(
        input_size=input_size,
        hidden_size=hidden_size,
        batch_first=True,
    )

def mlp(in_features, out_features, hidden_features=None, dropout=0.2):
    """Return the classification head: Linear -> Dropout -> Linear.

    Args:
        in_features: size of the incoming feature vector.
        out_features: number of output logits.
        hidden_features: width of the intermediate layer. Defaults to
            ``out_features // 2`` (the historical behaviour), clamped to at
            least 1 so that ``out_features=1`` does not create a zero-width
            layer.
        dropout: dropout probability applied between the two linear layers.

    Returns:
        A ``torch.nn.Sequential`` mapping ``(..., in_features)`` to
        ``(..., out_features)``.

    Note:
        There is no activation between the two linear layers, so in eval mode
        (dropout inactive) the head is one affine map factored through
        ``hidden_features`` units.
    """
    if hidden_features is None:
        hidden_features = max(1, out_features // 2)
    return torch.nn.Sequential(
        torch.nn.Linear(in_features, hidden_features),
        torch.nn.Dropout(dropout),
        torch.nn.Linear(hidden_features, out_features),
    )

In [8]:
class CNN(torch.nn.Module):
    """Text classifier: pretrained embeddings -> parallel 1-D convs -> MLP head.

    Three ``conv`` branches with kernel widths 2, 3 and 4 run over the
    embedded sequence; each is max-pooled over the whole sequence and the
    pooled features are concatenated before the ``mlp`` head.

    Args:
        emb_list: pretrained embedding matrix passed to ``emb``
            (``num_embeddings x embedding_dim``).
        num_filters: output channels of each convolution branch.
        num_classes: number of output logits.
    """

    def __init__(self, emb_list, num_filters=100, num_classes=6):
        super().__init__()
        self.emb = emb(emb_list)
        # Take the channel count from the embedding matrix rather than
        # assuming 300-dimensional vectors.
        in_channels = self.emb.embedding_dim
        self.conv2 = conv(in_channels, num_filters, 2)
        self.conv3 = conv(in_channels, num_filters, 3)
        self.conv4 = conv(in_channels, num_filters, 4)
        self.mlp = mlp(3 * num_filters, num_classes)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, num_classes)``.

        ``seq_len`` must be at least 4, the widest kernel.

        NOTE(review): batches built by ``pad`` in this file end with a column
        holding the sequence length; it is embedded here like any other token
        id. Confirm this is intended.
        """
        x = self.emb(x)
        # Conv1d expects channels first: (batch, emb_dim, seq_len).
        x = x.permute(0, 2, 1)
        pooled = [branch(x) for branch in (self.conv2, self.conv3, self.conv4)]
        return self.mlp(torch.cat(pooled, -1))

# model = CNN(embs)
# token_indices = [[1,5,7,8],[8,4,2,9]]
# model(torch.tensor(token_indices)).shape

In [9]:
class SepCNN(torch.nn.Module):
    """Text classifier using depthwise-separable convolution branches.

    Same layout as the plain CNN model in this file, but each branch is built
    with ``sep_conv`` (kernel widths 2, 3 and 4). The pooled branch outputs
    are concatenated and passed to the ``mlp`` head.

    Args:
        emb_list: pretrained embedding matrix passed to ``emb``
            (``num_embeddings x embedding_dim``).
        num_filters: output channels of each convolution branch.
        num_classes: number of output logits.
    """

    def __init__(self, emb_list, num_filters=100, num_classes=6):
        super().__init__()
        self.emb = emb(emb_list)
        # Take the channel count from the embedding matrix rather than
        # assuming 300-dimensional vectors.
        in_channels = self.emb.embedding_dim
        self.conv2 = sep_conv(in_channels, num_filters, 2)
        self.conv3 = sep_conv(in_channels, num_filters, 3)
        self.conv4 = sep_conv(in_channels, num_filters, 4)
        self.mlp = mlp(3 * num_filters, num_classes)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, num_classes)``.

        ``seq_len`` must be at least 4, the widest kernel.

        NOTE(review): batches built by ``pad`` in this file end with a column
        holding the sequence length; it is embedded here like any other token
        id. Confirm this is intended.
        """
        x = self.emb(x)
        # Conv1d expects channels first: (batch, emb_dim, seq_len).
        x = x.permute(0, 2, 1)
        pooled = [branch(x) for branch in (self.conv2, self.conv3, self.conv4)]
        return self.mlp(torch.cat(pooled, -1))

# model = SepCNN(embs)
# token_indices = [[1,5,7,8],[8,4,2,9]]
# model(torch.tensor(token_indices)).shape

In [10]:
class LSTM(torch.nn.Module):
    """Text classifier: pretrained embeddings -> LSTM -> MLP head.

    Args:
        emb_list: pretrained embedding matrix passed to ``emb``
            (``num_embeddings x embedding_dim``).
        hidden_size: LSTM hidden state size.
        num_classes: number of output logits.
    """

    def __init__(self, emb_list, hidden_size=300, num_classes=6):
        super().__init__()
        self.emb = emb(emb_list)
        # The input size follows the embedding matrix rather than a
        # hard-coded 300.
        self.lstm = lstm(self.emb.embedding_dim, hidden_size)
        self.mlp = mlp(hidden_size, num_classes)

    def forward(self, x):
        """Map token ids ``(batch, seq_len)`` to logits ``(batch, num_classes)``.

        The prediction is made from the LSTM output at the last position of
        the (padded) sequence.

        NOTE(review): batches built by ``pad`` in this file are right-padded
        and end with a column holding the sequence length, so the last
        position read here is that length value embedded as a token id, not
        the last real token. Confirm this is intended.
        """
        x = self.emb(x)
        x, _ = self.lstm(x)
        return self.mlp(x[:, -1, :])

# model = LSTM(embs)
# token_indices = [[1,5,7,3],[8,4,2,4]]
# model(torch.tensor(token_indices)).shape

In [11]:
class MyDataset(Dataset):
    """Dataset yielding ``(token_ids, topic)`` pairs from a DataFrame.

    Args:
        df: DataFrame with string columns ``title`` and ``text`` and an
            integer column ``topic``.
        w2i: mapping from word to integer token id.
    """

    def __init__(self, df, w2i):
        super().__init__()
        self.df, self.w2i = df, w2i

    def tokenize(self, text):
        """Split ``text`` on whitespace and return the ids of known words.

        Words missing from ``w2i`` are dropped.
        """
        return [self.w2i[w] for w in text.split() if w in self.w2i]

    def __getitem__(self, idx):
        """Return ``(ids, topic)`` for the row at position ``idx``.

        ``ids`` is a 1-D long tensor for the text ``"<title>. <text>"``; it is
        empty when no word is in the vocabulary. ``topic`` is a 0-D long
        tensor.
        """
        # Read from the instance's frame, not a module-level ``df``.
        row = self.df.iloc[idx]
        s = row['title'] + '. ' + row['text']
        ids = torch.tensor(self.tokenize(s), dtype=torch.long)
        topic = torch.tensor(row['topic'], dtype=torch.long)
        return ids, topic

    def __len__(self):
        return len(self.df)

# dataset = MyDataset(df, w2i)
# dataset[333]

In [12]:
def pad(X, pad_id=None, min_len=4, append_length=True):
    """Right-pad 1-D token-id tensors to a common length and stack them.

    Args:
        X: non-empty sequence of 1-D integer tensors.
        pad_id: id used for padding. Defaults to the module-level
            ``w2i['PAD']``.
        min_len: minimum padded length. The default of 4 keeps every sequence
            at least as long as the widest convolution kernel used by the
            models in this file.
        append_length: when true (the default, historical behaviour), the
            original length of each sequence is appended as one extra
            trailing element of its row.

    Returns:
        A long tensor of shape ``(len(X), width + 1)`` when ``append_length``
        is true and ``(len(X), width)`` otherwise, where
        ``width = max(min_len, longest sequence)``.

    NOTE(review): the models in this file feed the whole row to the embedding
    layer, so the appended length is looked up as if it were a token id.
    Pass ``append_length=False`` if that is not wanted.
    """
    if pad_id is None:
        pad_id = w2i['PAD']
    width = max(min_len, max(len(x) for x in X))

    rows = []
    for x in X:
        pieces = [x, torch.full((width - len(x),), pad_id, dtype=torch.long)]
        if append_length:
            pieces.append(torch.tensor([len(x)], dtype=torch.long))
        rows.append(torch.cat(pieces))
    return torch.stack(rows)

def collate_fn(batch):
    """Collate ``(ids, label)`` samples into a padded id batch and a label batch.

    The id tensors go through ``pad``; the labels are stacked into one tensor.
    """
    sequences, labels = zip(*batch)
    padded = pad(sequences)
    stacked_labels = torch.stack(labels)
    return padded, stacked_labels

# loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
# batch = next(iter(loader))
# batch[0].shape, batch[1].shape

In [None]:
def metrics_stat(y_true, y_pred):
    """Return accuracy (``'acc'``) and macro-averaged F1 (``'f1'``) as a dict."""
    accuracy = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    return {'acc': accuracy, 'f1': macro_f1}

In [13]:
def nn_epoch(model, optimizer, loss_fn,
             loader, device, train=False):
    """Run one pass over ``loader``, optionally updating the model.

    Args:
        model: module mapping a batch of inputs to class logits.
        optimizer: optimizer stepped after every batch when ``train`` is
            true; unused (may be ``None``) otherwise.
        loss_fn: callable ``loss_fn(logits, labels)`` returning a scalar loss.
        loader: iterable of ``(input_ids, labels)`` batches.
        device: device the model and batches are moved to.
        train: when true, run in training mode with gradients and optimizer
            steps; when false, run in eval mode without gradients.

    Returns:
        Dict with ``'preds'`` (argmax class per sample, CPU tensor),
        ``'targets'`` (labels, CPU tensor) and ``'loss'`` (mean of the
        per-batch losses; ``nan`` if the loader yielded no batches).
    """
    train = bool(train)
    model = model.to(device)
    model.train(train)

    pbar = tqdm(loader)
    total_loss, num_batches = 0.0, 0
    # Collect per-batch chunks and concatenate once at the end instead of
    # re-allocating the accumulated tensor on every step.
    pred_chunks, target_chunks = [], []

    # Scope the grad-mode switch to this call so that an eval pass does not
    # leave autograd disabled for whatever runs next.
    with torch.set_grad_enabled(train):
        for input_ids, labels in pbar:

            input_ids = input_ids.to(device)
            labels = labels.to(device).long()

            if train:
                optimizer.zero_grad()

            logits = model(input_ids)
            loss = loss_fn(logits, labels)

            if train:
                loss.backward()
                optimizer.step()

            batch_loss = loss.item()
            target_chunks.append(labels.cpu())
            pred_chunks.append(torch.argmax(logits.detach(), 1).cpu())
            pbar.set_description(f'loss: {batch_loss:.4}')
            total_loss += batch_loss
            num_batches += 1

    avg_loss = total_loss / num_batches if num_batches else float('nan')
    pbar.set_description(f'loss: {avg_loss:.4}')

    empty = torch.empty(0, dtype=torch.long)
    preds = torch.cat(pred_chunks) if pred_chunks else empty
    targets = torch.cat(target_chunks) if target_chunks else empty
    return {'preds': preds, 'targets': targets, 'loss': avg_loss}

In [16]:
# Mount Google Drive inside the Colab VM; the data archive is read from it below.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Extract the data archive from Drive into the local ./data directory.
!unzip /content/drive/MyDrive/shared/topic/data.zip -d data

In [32]:
# Fix the NumPy/PyTorch seeds before any model or split is created.
set_seed(42)
# Train on the GPU when available, otherwise on the CPU.
device = get_available_device()

In [33]:
# Preprocessed articles and the pretrained word -> vector table.
df = pd.read_feather('data/prep.ftr')
w2v_embs = load_pickle('data/w2v_embs.bin')

# Padding token: an all-zero 300-dimensional vector.
w2v_embs['PAD'] = np.zeros((300,), dtype=np.float32)

# Row i of `embs` is the vector of the word whose id in `w2i` is i
# (both follow the table's iteration order).
w2i = dict(zip(w2v_embs.keys(), range(len(w2v_embs))))
embs = np.stack(tuple(w2v_embs.values()))

In [34]:
# Dedicated generator so the train/validation split is reproducible.
generator = torch.Generator().manual_seed(42)

# 80% of the samples for training, the remainder for validation.
dataset = MyDataset(df, w2i)
train_size = int(len(dataset) * 0.800)
valid_size = len(dataset) - train_size

In [35]:
# Random split driven by the seeded generator.
train_dataset, valid_dataset = random_split(
    dataset, [train_size, valid_size], generator=generator
)
# Only the training loader shuffles; both pad their batches via collate_fn.
train_loader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn
)
valid_loader = DataLoader(
    valid_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn
)

In [36]:
model = LSTM(embs)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# `ignore_index` refers to a *target class* index, not a token id, so passing
# the PAD token id here was meaningless (and would silently drop a class if
# that id ever fell inside the label range). All topic labels are valid.
loss_fn = torch.nn.CrossEntropyLoss()

In [37]:
num_epochs = 6

# One training pass followed by one validation pass per epoch; a checkpoint
# is written after every epoch.
for epoch in range(num_epochs):

    for split, split_loader, is_train in (
        ('train', train_loader, True),
        ('valid', valid_loader, False),
    ):
        res = nn_epoch(
            model, optimizer, loss_fn,
            split_loader, device, train=is_train
        )
        print(split, metrics_stat(res['targets'], res['preds']))

    torch.save(model.state_dict(), f'data/model_{epoch+1:02}.pt')

  0%|          | 0/10804 [00:00<?, ?it/s]

train {'acc': 0.7237523356646594, 'f1': 0.6813519418742978}


  0%|          | 0/2701 [00:00<?, ?it/s]

valid {'acc': 0.9162925768233987, 'f1': 0.9072378524811892}


  0%|          | 0/10804 [00:00<?, ?it/s]

train {'acc': 0.8656305860710505, 'f1': 0.8478341207096922}


  0%|          | 0/2701 [00:00<?, ?it/s]

valid {'acc': 0.9247269529803777, 'f1': 0.918194441888978}


  0%|          | 0/10804 [00:00<?, ?it/s]

train {'acc': 0.883196519787346, 'f1': 0.8701781140706388}


  0%|          | 0/2701 [00:00<?, ?it/s]

valid {'acc': 0.9241600333209923, 'f1': 0.9180976052807391}


  0%|          | 0/10804 [00:00<?, ?it/s]

train {'acc': 0.889834724608505, 'f1': 0.8779597227444235}


  0%|          | 0/2701 [00:00<?, ?it/s]

valid {'acc': 0.9317035357275083, 'f1': 0.9251920543586255}


  0%|          | 0/10804 [00:00<?, ?it/s]

train {'acc': 0.8958510496751763, 'f1': 0.8842492340791209}


  0%|          | 0/2701 [00:00<?, ?it/s]

valid {'acc': 0.9338439466864125, 'f1': 0.9276628103892647}


  0%|          | 0/10804 [00:00<?, ?it/s]

train {'acc': 0.9009446787340264, 'f1': 0.8896533347622763}


  0%|          | 0/2701 [00:00<?, ?it/s]

valid {'acc': 0.9355215660866346, 'f1': 0.9295253870754764}


In [None]:
# CNN acc 0.929 after 5 epochs
# SepCNN acc 0.924 after 5 epochs
# LSTM acc 0.936 after 6 epochs

In [38]:
# Final evaluation pass over the validation set; no optimizer is needed when train=False.
res = nn_epoch(model, None, loss_fn, valid_loader, device, train=False)

  0%|          | 0/2701 [00:00<?, ?it/s]

In [39]:
# Validation accuracy.
accuracy_score(res['targets'], res['preds'])

0.9355215660866346

In [40]:
# Validation macro-averaged F1.
f1_score(res['targets'], res['preds'], average='macro')

0.9295253870754764

In [41]:
# Confusion matrix (rows: true topic, columns: predicted topic).
print(confusion_matrix(res['targets'], res['preds']))

[[ 7956   284   367   284    33   138]
 [  271 10077   253    80    22     7]
 [  433   222 25905   338    54   301]
 [  312    94   502  9670     5   106]
 [   24   164   127    10 12527     1]
 [  366    57   513   190    15 14724]]


In [42]:
# Per-class precision, recall and F1 on the validation set.
print(classification_report(res['targets'], res['preds']))

              precision    recall  f1-score   support

           0       0.85      0.88      0.86      9062
           1       0.92      0.94      0.93     10710
           2       0.94      0.95      0.94     27253
           3       0.91      0.90      0.91     10689
           4       0.99      0.97      0.98     12853
           5       0.96      0.93      0.95     15865

    accuracy                           0.94     86432
   macro avg       0.93      0.93      0.93     86432
weighted avg       0.94      0.94      0.94     86432



In [43]:
# Topic names stored in labels.json (presumably in class-index order — TODO confirm).
list(load_json('data/labels.json').keys())

['Интернет и СМИ', 'Культура', 'Мир', 'Наука и техника', 'Спорт', 'Экономика']