In [2]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

class MetaTrollDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'label': torch.tensor(label)
        }

class MetaTrollClassifier(nn.Module):
    def __init__(self, model_name='xlm-roberta-base', num_labels=2):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.adapter = nn.Sequential(
            nn.Linear(self.model.config.hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_labels)
        )
        
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]  # CLS token
        return self.adapter(pooled_output)

class MetaTrollTrainer:
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.to(device)
        
    def create_episode(self, support_set, query_set, k_shot=5):
        """Create a few-shot learning episode"""
        # Randomly sample k examples per class for support set
        support_data = []
        support_labels = []
        for label in [0, 1]:  # Binary classification
            indices = (support_set['labels'] == label).nonzero()[0]
            selected = np.random.choice(indices, k_shot, replace=False)
            support_data.extend(support_set['texts'][selected])
            support_labels.extend([label] * k_shot)
            
        # Rest goes to query set
        query_data = query_set['texts']
        query_labels = query_set['labels']
        
        return {
            'support': (support_data, support_labels),
            'query': (query_data, query_labels)
        }
    
    def train_episode(self, episode, optimizer, criterion):
        self.model.train()
        
        # Train on support set
        support_data, support_labels = episode['support']
        support_dataset = MetaTrollDataset(support_data, support_labels, self.tokenizer)
        support_loader = DataLoader(support_dataset, batch_size=len(support_dataset))
        
        for batch in support_loader:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device)
            
            outputs = self.model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        
        # Evaluate on query set
        self.model.eval()
        query_data, query_labels = episode['query']
        query_dataset = MetaTrollDataset(query_data, query_labels, self.tokenizer)
        query_loader = DataLoader(query_dataset, batch_size=32)
        
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in query_loader:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)
                
                outputs = self.model(input_ids, attention_mask)
                _, predicted = torch.max(outputs, 1)
                
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        
        return correct / total

In [4]:
import sys
sys.path.append('../src')
from src.data_tools.czech_data_tools import load_czech_media_data
import pandas as pd

# Load all comments data once
print("Loading all comments data...")
all_comments_df = load_czech_media_data("./data/MediaSource")

# Load annotations
print("Loading annotations...")
annotations_df = pd.read_csv('../workspace/annotations/user_labels.csv')
certain_df = annotations_df[annotations_df['label'].isin([0, 1])]

# Prepare data
data = {
    'texts': [],
    'labels': []
}

# Extract comments for each annotated user from the already loaded dataframe
print("Extracting comments for annotated users...")
for _, row in certain_df.iterrows():
    user_comments = all_comments_df[all_comments_df['author'] == row['author']]['text'].tolist()
    if user_comments:  # Only add if user has comments
        data['texts'].append(' '.join(user_comments))
        data['labels'].append(row['label'])
        print(f"Found {len(user_comments)} comments for user {row['author']}")
    else:
        print(f"No comments found for user {row['author']}")

print(f"\nFinal dataset: {len(data['texts'])} users with comments")

Loading all comments data...


Loading files: 100%|██████████████████████████████████████████████████████| 124/124 [00:10<00:00, 11.30it/s]


Loading annotations...
Extracting comments for annotated users...
Found 39 comments for user Štěpán Malák
Found 64 comments for user Jan Benda
Found 22 comments for user Jindra Macek
Found 40 comments for user Josef Fortelný
Found 37 comments for user Michal Musil
Found 82 comments for user Pavel Rehberger
Found 48 comments for user Vladimír Kalinay
Found 42 comments for user Petr Jelinek
Found 739 comments for user Jan Sykora
Found 20 comments for user Radek Palán
Found 64 comments for user Jan Trejbal
Found 99 comments for user Michal Antonín
Found 33 comments for user Gabi Muller
Found 347 comments for user Michal Žák
Found 25 comments for user Ivan Penzes
Found 102 comments for user Richard Benes
Found 55 comments for user Martin Ondík
Found 71 comments for user Jan Kozohorský
Found 21 comments for user Petr Mojžíš
Found 44 comments for user Tomáš Souček
Found 48 comments for user Libor Weizenbauer
Found 196 comments for user Jan Velebil
Found 60 comments for user Jakub Ručka
Found

In [7]:
# After loading and processing comments
import numpy as np

# Convert lists to numpy arrays for easier handling
data = {
    'texts': np.array(data['texts']),
    'labels': np.array(data['labels'])
}

# Initialize model and trainer
model = MetaTrollClassifier(model_name='xlm-roberta-base')
trainer = MetaTrollTrainer(model)
tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
trainer.tokenizer = tokenizer  # Add tokenizer to trainer

# Training loop
num_episodes = 100
k_shot = 5
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

print(f"Starting training with {len(data['texts'])} examples")
print(f"Label distribution: {np.bincount(data['labels'])}")

for episode in range(num_episodes):
    # Create episode
    episode_data = trainer.create_episode(
        support_set={'texts': data['texts'], 'labels': data['labels']},
        query_set={'texts': data['texts'], 'labels': data['labels']},
        k_shot=k_shot
    )
    
    # Train on episode
    accuracy = trainer.train_episode(episode_data, optimizer, criterion)
    print(f"Episode {episode + 1}, Accuracy: {accuracy:.4f}")

Starting training with 34 examples
Label distribution: [17 17]
Episode 1, Accuracy: 0.5000
Episode 2, Accuracy: 0.5000
Episode 3, Accuracy: 0.5882
Episode 4, Accuracy: 0.5294
Episode 5, Accuracy: 0.5000
Episode 6, Accuracy: 0.5000
Episode 7, Accuracy: 0.5000
Episode 8, Accuracy: 0.5000
Episode 9, Accuracy: 0.5000
Episode 10, Accuracy: 0.5000
Episode 11, Accuracy: 0.5000
Episode 12, Accuracy: 0.5000
Episode 13, Accuracy: 0.5000
Episode 14, Accuracy: 0.5000
Episode 15, Accuracy: 0.5000
Episode 16, Accuracy: 0.5000
Episode 17, Accuracy: 0.5000
Episode 18, Accuracy: 0.5000
Episode 19, Accuracy: 0.5000
Episode 20, Accuracy: 0.5000


KeyboardInterrupt: 