In [1]:
!pip install pyterrier
!pip install unidecode
!pip install sentence_transformers
!pip install torch_scatter

Collecting pyterrier
  Downloading pyterrier-0.1.5-py2.py3-none-any.whl.metadata (9.3 kB)
Downloading pyterrier-0.1.5-py2.py3-none-any.whl (22 kB)
Installing collected packages: pyterrier
Successfully installed pyterrier-0.1.5
Collecting unidecode
  Downloading Unidecode-1.3.8-py3-none-any.whl.metadata (13 kB)
Downloading Unidecode-1.3.8-py3-none-any.whl (235 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m235.5/235.5 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: unidecode
Successfully installed unidecode-1.3.8
Collecting sentence_transformers
  Downloading sentence_transformers-3.3.1-py3-none-any.whl.metadata (10 kB)
Downloading sentence_transformers-3.3.1-py3-none-any.whl (268 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentence_transformers
Successfully installed sentence_transformers-3.3.1
Colle

In [2]:
from dataclasses import dataclass, asdict, field
import os
from typing import Any, Dict, Optional
import yaml
import re
import string
import datetime
import ast
from os.path import join as join_path
from collections import *
import argparse
import copy
import glob
from functools import partial
from statistics import mean, median, StatisticsError
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, T5Config
import collections
from torch.optim import Optimizer
from sentence_transformers import SentenceTransformer
import itertools
from nltk.tokenize import sent_tokenize
import numpy as np
from torch_scatter import scatter
import json
from typing import *
import math
import random
import statistics
from scipy.stats import norm
from argparse import Namespace
import logging
import os
import shutil
import string
import pandas as pd
import pyterrier as pt
from unidecode import unidecode

@dataclass
class Model:
    name: str = 'sentence-transformers/all-mpnet-base-v2'
    pooling: str = 'default'  # 'mean', 'default'


@dataclass
class Train:
    device: str = 'cuda' 
    dtype: str = 'float32'  # float32, float16, bfloat16
    batch_size: int = 8 
    num_epochs: int = 20
    metrics_window_size: int = 128
    warmup_steps: int = 400
    evaluation_steps: int = -1  # If -1, after each epoch
    label_smoothing: float = 0.  # [0, 1]
    ema: bool = False  # Exponential moving average
    freezing: bool = False  # Gradually makes early modules untrainable
    output_path: Optional[str] = os.path.join('/kaggle/working/')  # If None, no model checkpointing every evaluation_steps
    save_vectors: bool = False  # Whether to save vectors for evaluation
    save_models: bool = True # Whether to save models
    p_max_seq_length: int = 512
    fc_max_seq_length: int = 256

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

@dataclass
class Optimizer:
    name: str = 'adamw'  # adamw, shampoo_fb, shampoo_google
    lr: float = 5e-6
    weight_decay: float = 0.005 
    clip_value: float = 1.0  # if None, no gradient clipping
    sam: bool = False
    sam_step: float = 1.  # Will be ignored if sam is False, propability of sam step
    sam_rho: float = 0.05  # if sam_step=10 then rho=0.5,

    
@dataclass
class Config:
    model: Model = Model()
    train: Train = Train()
    optimizer: Optimizer = Optimizer()

    seed: int = 3407

    def __init__(self, path: Optional[str] = None):
        self.timestamp = datetime.datetime.now().strftime('%d-%m-%Y-%H-%M-%S')
        if path is not None: 
            with open(path, 'r') as f: self.init_class(yaml.load(f, Loader=yaml.SafeLoader))
        
    def to_dict(self):
        return asdict(self)
    
    def init_class(self, d):
        for name in dir(self):
            if name.startswith('_') or name.endswith('_') or name not in d:
                continue
            attr = getattr(self, name)
            if isinstance(d[name], dict):
                for k, v in d[name].items():
                    setattr(attr, k, v)
            else: 
                setattr(self, name, d[name])

In [3]:
"""
Library of various cleaning-related functions, regular expressions and variables.
"""


simple_latin = string.ascii_lowercase + string.ascii_uppercase
dirty_chars = string.digits + string.punctuation


def is_clean_text(text: str) -> bool:
    """
    Simple text cleaning method.
    """
    dirty = (
        len(text) < 25                                               # Short text
        or
        0.5 < sum(char in dirty_chars for char in text) / len(text)  # More than 50% dirty chars                                            
    )
    return not dirty


# Source: https://gist.github.com/dperini/729294
url_regex = re.compile(
    r'(?:^|(?<![\w\/\.]))'
    r'(?:(?:https?:\/\/|ftp:\/\/|www\d{0,3}\.))'
    r'(?:\S+(?::\S*)?@)?' r'(?:'
    r'(?!(?:10|127)(?:\.\d{1,3}){3})'
    r'(?!(?:169\.254|192\.168)(?:\.\d{1,3}){2})'
    r'(?!172\.(?:1[6-9]|2\d|3[0-1])(?:\.\d{1,3}){2})'
    r'(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])'
    r'(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}'
    r'(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))'
    r'|'
    r'(?:(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)'
    r'(?:\.(?:[a-z\\u00a1-\\uffff0-9]-?)*[a-z\\u00a1-\\uffff0-9]+)*'
    r'(?:\.(?:[a-z\\u00a1-\\uffff]{2,}))' r'|' r'(?:(localhost))' r')'
    r'(?::\d{2,5})?'
    r'(?:\/[^\)\]\}\s]*)?',
    flags=re.IGNORECASE,
)


def remove_urls(text: str) -> str:
    return url_regex.sub('', text)


# Source: https://gist.github.com/Nikitha2309/15337f4f593c4a21fb0965804755c41d
emoji_regex = re.compile('['
        u'\U0001F600-\U0001F64F'  # emoticons
        u'\U0001F300-\U0001F5FF'  # symbols & pictographs
        u'\U0001F680-\U0001F6FF'  # transport & map symbols
        u'\U0001F1E0-\U0001F1FF'  # flags (iOS)
        u'\U00002500-\U00002BEF'  # chinese char
        u'\U00002702-\U000027B0'
        u'\U00002702-\U000027B0'
        u'\U000024C2-\U0001F251'
        u'\U0001f926-\U0001f937'
        u'\U00010000-\U0010ffff'
        u'\u2640-\u2642'
        u'\u2600-\u2B55'
        u'\u200d'
        u'\u23cf'
        u'\u23e9'
        u'\u231a'
        u'\ufe0f'  # dingbats
        u'\u3030'
    ']+')


def remove_emojis(text: str) -> str:
    return emoji_regex.sub('', text)


sentence_stop_regex = re.compile('['
    u'\u002e' # full stop
    u'\u2026' # ellipsis
    u'\u061F' # arabic question mark
    u'\u06D4' # arabic full stop
    u'\u2022' # bullet point
    u'\u3002' # chinese period
    u'\u25CB' # white circle
    '\|'      # pipe
']+')


def replace_stops(text: str) -> str:
    """
    Replaces some characters that are being used to end sentences. Used for sentence segmentation with sliding windows.
    """
    return sentence_stop_regex.sub('.', text)


whitespace_regex = re.compile(r'\s+')


def replace_whitespaces(text: str) -> str:
    return whitespace_regex.sub(' ', text)


def clean_ocr(ocr: str) -> str:
    """
    Remove all lines that are shorter than 6 and have more than 50% `dirty_chars`.
    """
    return '\n'.join(
        line
        for line in ocr.split('\n')
        if len(line) > 5 and sum(char in dirty_chars for char in line) / len(line) < 0.5
    )


def clean_twitter_picture_links(text):
    """
    Replaces links to picture in twitter post only with 'pic'. 
    """
    return re.sub(r'pic.twitter.com/\S+', 'pic', text)


def clean_twitter_links(text):
    """
    Replaces twitter links with 't.co'.
    """
    return re.sub(r'\S+//t.co/\S+', 't.co', text)


def remove_elongation(text):
    """
    Replaces any occurrence of a string of consecutive identical non-space 
    characters (at least three in a row) with just one instance of that character.
    """
    text = re.sub(r'(\S+)\1{2,}', r'\1', text)
    return text

def safe_literal_eval(value):
    try:
        return ast.literal_eval(str(value))
    except (ValueError, SyntaxError):
        return value  # Or `None`, depending on how you want to handle it

In [4]:
"""
Library of custom types used for datasets. Some basic functions over these types are also included.
"""

Id2FactCheck = Dict[int, str]
Id2Post = Dict[int, str]
FactCheckPostMapping = List[Tuple[int, int]]

Language = str  # ISO 639-3 language code
LanguageDistribution = Dict[Language, float]  # Confidence for language identification. Should sum to [0, 1].

OriginalText = str
EnglishTranslation = str
TranslatedText = Tuple[OriginalText, EnglishTranslation, LanguageDistribution]

Instance = Tuple[Optional[datetime.datetime], str]  # When and where was a fact-check or post published


def is_in_distribution(language: Language, distribution: LanguageDistribution, threshold: float = 0.2) -> bool:
    """
    Check whether `language` is in a `distribution` with more than `treshold` x 100%
    """
    return next(
        (
            percentage >= threshold
            for distribution_language, percentage in distribution
            if distribution_language == language
        ),
        False
    )


def combine_distributions(texts: Iterable[TranslatedText]) -> LanguageDistribution:
    """
    Combine `LanguageDistribution`s from multiple `TranslatedText`s taking the length of the text into consideration.
    """
    total_length = sum(len(text[0]) for text in texts)
    distribution = defaultdict(lambda: 0)
    for original_text, _, text_distribution in texts:
        for language, percentage in text_distribution:
            distribution[language] += percentage * len(original_text) / total_length
    return list(distribution.items())

In [5]:
class Dataset:
    """Dataset
    
    Abstract class for datasets. Subclasses should implement `load` function that load `id_to_fact_check`, `id_to_post`,
    and `fact_check_post_mapping` object attributes. The class also implemenets basic cleaning methods that might be
    reused.
    
    Attributes:
        clean_ocr: bool = True  Should cleaning of OCRs be performed
        remove_emojis: bool = False  Should emojis be removed from the texts?
        remove_urls: bool = True  Should URLs be removed from the texts?
        replace_whitespaces: bool = True  Should whitespaces be replaced by a single space whitespace?
        clean_twitter: bool = True  
        remove_elongation: bool = False  Should occurrence of a string of consecutive identical non-space 
    characters (at least three in a row) with just one instance of that character?

        After `load` is called, following attributes are accesible:
                fact_check_post_mapping: list[tuple[int, int]]  List of Factcheck-Post id pairs.
                id_to_fact_check: dict[int, str]  Factcheck id -> Factcheck text
                id_to_post: dict[int, str]  Post id -> Post text
                
    Methods:
        clean_text: Performs text cleaning based on initialization attributes.
        maybe_clean_ocr: Perform OCR-specific text cleaning, if `self.clean_ocr`
        load: Abstract method. To be implemented by the subclasses.
        
    """
    
    # The default values here are based on our preliminary experiments. Might not be the best for all cases.
    def __init__(
        self,
        clean_ocr: bool = True,
        dataset: str = None,  # Here to read and discard the `dataset` field from the argparser
        remove_emojis: bool = True,
        remove_urls: bool = True,
        replace_whitespaces: bool = True,
        clean_twitter: bool = True,
        remove_elongation: bool = False
    ):
        self.clean_ocr = clean_ocr
        self.remove_emojis = remove_emojis
        self.remove_urls = remove_urls
        self.replace_whitespaces = replace_whitespaces
        self.clean_twitter = clean_twitter
        self.remove_elongation = remove_elongation
        
        
    def __len__(self):
        return len(self.fact_check_post_mapping)

    
    def __getitem__(self, idx):
        p_id, fc_id = self.fact_check_post_mapping[idx]
        return self.id_to_fact_check[fc_id], self.id_to_post[p_id]

        
    def clean_text(self, text):
        
        if self.remove_urls:
            text = remove_urls(text)

        if self.remove_emojis:
            text = remove_emojis(text)

        if self.replace_whitespaces:
            text = replace_whitespaces(text)
        
        if self.clean_twitter:
            text = clean_twitter_picture_links(text)
            text = clean_twitter_links(text)
        
        if self.remove_elongation:
            text = remove_elongation(text)

        return text.strip()        
        
        
    def maybe_clean_ocr(self, ocr):
        if self.clean_ocr:
            return clean_ocr(ocr)
        return ocr
        
    
    def __getattr__(self, name):
        if name in {'id_to_fact_check', 'id_to_post', 'fact_check_post_mapping'}:
            raise AttributeError(f"You have to `load` the dataset first before using '{name}'")
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

        
    def load(self):
        raise NotImplementedError
        


In [6]:

class OurDataset(Dataset):
    """Our dataset
    
    Class for our dataset that can load different variants. It requires data from `drive/datasets/ours`.

    Initialization Attributes:
        crosslingual: bool  If `True`, only crosslingual pairs (fact-check and post in different languages) are loaded.
        fact_check_fields: Iterable[str]  List of fields used to generate the final `str` representation for fact-checks. Supports `claim` and `title`.
        fact_check_language: Optional[Language]  If a `Language` is specified, only fact-checks with that language are selected.
        language: Optional[Language]  If a `Language` is specified, only fact-checks and posts with that language are selected.
        post_language: Optional[Language]  If a `Language` is specified, only posts with that language are selected.
        split: `train`, `test` or `dev`. `None` means all the samples.
        version: 'original' or 'english'. Language version of the dataset.
        
        Also check `Dataset` attributes.
        
        After `load` is called, following attributes are accesible:
            fact_check_post_mapping: list[tuple[int, int]]  List of Factcheck-Post id pairs.
            id_to_fact_check: dict[int, str]  Factcheck id -> Factcheck text
            id_to_post: dict[int, str]  Post id -> Post text
        

    Methods:
        load: Loads the data from the csv files. Populates `id_to_fact_check`, `id_to_post` and `fact_check_post_mapping` attributes.
    """
        
    our_dataset_path = join_path('/semeval-data/SemEval_Task7/')
    csvs_loaded = False

    
    def __init__(
        self,
        crosslingual: bool = False,
        fact_check_fields: Iterable[str] = ('claim', 'title'),
        fact_check_language: Optional[Language] = None,
        language: Optional[Language] = None,
        post_language: Optional[Language] = None,
        split: Optional[str] = None,
        # version: str = 'original',
        fold: int = 0,
        **kwargs
    ):
        super().__init__(**kwargs)
        
        assert all(field in ('claim', 'title') for field in fact_check_fields)
        assert split in ('test', 'train')
        
        # self.crosslingual = crosslingual
        self.split = split
        self.fold = fold
        
    @classmethod
    def maybe_load_csvs(cls):
        """
        Load the csvs and store them as class variables. When individual objects are initialized, they can reuse the same
        pre-loaded dataframes without costly text parsing.
        
        `OurDataset.csvs_loaded` is a flag indicating whether the csvs are already loaded.
        """
        
        if cls.csvs_loaded:
            return
        
        posts_path = join_path(cls.our_dataset_path, 'posts.csv')
        fact_checks_path = join_path(cls.our_dataset_path, 'fact_checks.csv')
        fact_check_post_mapping_path = join_path(cls.our_dataset_path, 'pairs.csv')
        
        for path in [posts_path, fact_checks_path, fact_check_post_mapping_path]:
            assert os.path.isfile(path)

        
        parse_col = lambda s: ast.literal_eval(s.replace('\n', '\\n')) if s else s
        
        print('Loading fact-checks.')
        cls.df_fact_checks = pd.read_csv(fact_checks_path).fillna('').set_index('fact_check_id')
        for col in ['claim', 'instances', 'title']:
            cls.df_fact_checks[col] = cls.df_fact_checks[col].apply(parse_col)
        print(f'{len(cls.df_fact_checks)} loaded.')

            
        print('Loading posts.')
        cls.df_posts = pd.read_csv(posts_path).fillna('').set_index('post_id')
        for col in ['instances', 'ocr', 'verdicts', 'text']:
            cls.df_posts[col] = cls.df_posts[col].apply(parse_col)
        print(f'{len(cls.df_posts)} loaded.')

         
        print('Loading fact-check-post mapping.')
        cls.df_fact_check_post_mapping = pd.read_csv(fact_check_post_mapping_path) 
        print(f'{len(cls.df_fact_check_post_mapping)} loaded.')
        
        cls.csvs_loaded = True
        

    def load(self):
        
        self.maybe_load_csvs()
        
        df_posts = self.df_posts.copy()
        df_fact_checks = self.df_fact_checks.copy()
        df_fact_check_post_mapping = self.df_fact_check_post_mapping.copy()

        df_posts['text'] = df_posts['text'].apply(lambda x: safe_literal_eval(x))
        df_posts['ocr'] = df_posts['ocr'].apply(lambda x: safe_literal_eval(x))
        df_fact_checks['claim'] = df_fact_checks['claim'].apply(lambda x: safe_literal_eval(x))
        df_fact_checks['title'] = df_fact_checks['title'].apply(lambda x: safe_literal_eval(x))

        train_fold = pd.read_csv(f'/semeval-data/train_fold_task_7/fold_{self.fold}_train.csv')
        test_fold = pd.read_csv(f'/semeval-data/train_fold_task_7/fold_{self.fold}_test.csv')
        # full_data = pd.concat([train_fold, test_fold]).reset_index(drop = True)
      
        if self.split:
            fold_data = pd.read_csv(f'/semeval-data/train_fold_task_7/fold_{self.fold}_{self.split}.csv')
            df_posts = df_posts[df_posts.index.isin(fold_data['post_id'].values)]
            
            df_fact_check_post_mapping = df_fact_check_post_mapping[df_fact_check_post_mapping['post_id'].isin(fold_data['post_id'].values)]
            
            df_fact_checks_1 = df_fact_checks[df_fact_checks.index.isin(fold_data['fact_check_id'].unique())]
            df_fact_checks_2 = df_fact_checks[~(df_fact_checks.index.isin(fold_data['fact_check_id'].unique()))]
            df_fact_checks = pd.concat([df_fact_checks_1, df_fact_checks_2.sample(n = .30, random_state = 3407)])
    
            print(f'Filtering by split: {len(df_posts)} posts remaining and sampled {len(df_fact_checks)} fact checks')

        print(f'Filtering fact-checks by language: {len(df_fact_checks)} posts remaining.')
        print(f'Filtering posts by language: {len(df_posts)} posts remaining.')
        

        # Create mapping variable
        post_ids = set(df_posts.index)
        fact_check_ids = set(df_fact_checks.index)
        fact_check_post_mapping = set((fact_check_id, post_id) 
                                      for fact_check_id, post_id in df_fact_check_post_mapping.itertuples(index=False, name=None))
        print(f'Mappings remaining: {len(fact_check_post_mapping)}.')

        print(f'Filtering posts.')
        print(f'Posts remaining: {len(df_posts)}')
            
    
        # Create object attributes
        self.fact_check_post_mapping = list(fact_check_post_mapping)

        self.id_to_post = dict()
        for post_id, post_text, ocr_text in zip(df_posts.index, df_posts['text'], df_posts['ocr']):
            texts = list()
            if post_text:
                texts.append(self.maybe_clean_ocr(post_text[1]))
            if ocr_text:
                texts.append(self.maybe_clean_ocr(ocr_text[0][1]))
            self.id_to_post[post_id] = self.clean_text(' '.join(texts))

        self.id_to_fact_check = dict()
        for fact_check_id, claim, title in zip(df_fact_checks.index, df_fact_checks['claim'], df_fact_checks['title']):
            texts = list()
            if claim:
                texts.append(self.maybe_clean_ocr(claim[1]))
            if title:
                texts.append(self.maybe_clean_ocr(title[1]))
            self.id_to_fact_check[fact_check_id] = self.clean_text(' '.join(texts))
            
        return self
    

In [7]:
class DummyDataset(Dataset):
    """
    Very small dataset based on OurDataset with 100 fact-checks an 100 posts.
    """
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def load(self):
        
        dt = OurDataset(split='test').load()
        dt.fact_check_post_mapping = list(dt.fact_check_post_mapping)[:100]
        
        fact_check_ids, post_ids = map(set, zip(*dt.fact_check_post_mapping))
        
        dt.id_to_fact_check = {
            k: v
            for k, v in dt.id_to_fact_check.items()
            if k in fact_check_ids
        }
        
        dt.id_to_post = {
            k: v
            for k, v in dt.id_to_post.items()
            if k in post_ids
        }
        return dt

In [8]:
"""
A library of metric-calculating functions. The default result format is Iterable[List[int]] -- an iterable of results for individual queries (posts).
Each query has a list of ranks assigned, representing the ranks of appropriate documents (fact-checks).
"""

def binary_ci(success: int, total: int, alpha: float = 0.95) -> Tuple[float, float, float]:
    """
    Using Agresti-Coull interval
    
    Return mean and confidence interval (lower and upper bound)
    """
    z = statistics.NormalDist().inv_cdf((1 + alpha) / 2)
    total = total + z**2
    loc = (success + (z**2) / 2) / total
    diameter = z * math.sqrt(loc * (1 - loc) / total)
    return loc, loc - diameter, loc + diameter 


def bootstrap_ci(scores, alpha=0.95) -> Tuple[float, float, float]:
    """
    Bootstrapping based estimate.
    
    Return mean and confidence interval (lower and upper bound)
    """
    loc, scale = norm.fit(scores)    
    bootstrap = [sum(random.choices(scores, k=len(scores))) / len(scores) for _ in range(1000)]
    lower, upper = norm.interval(alpha, *norm.fit(bootstrap))
        
    return loc, lower, upper


def pair_success_at_k(ranks, k=10):
    """
    Pair S@K - How many fact-check-post pairs from all the pairs ended up in the top K.
    """
    values = [rank <= k for query in ranks for rank in query]
    return binary_ci(sum(values), len(values))

        
def post_success_at_k(ranks, k=10):
    """
    Post S@K - For how many posts at least one pair ended up in the top K.
    """
    values = [any(rank <= k for rank in query) for query in ranks]
    return binary_ci(sum(values), len(values))

        
def precision_at_k(ranks, k=10):
    """
    P@K - How many positive hits in the top K
    """
    values = [sum(rank <= k for rank in query) for query in ranks]
    return binary_ci(sum(values), len(values) * k)

        
def mrr(ranks):
    """
    Mean Reciprocal Rank: 1/r for r in ranks
    """
    values = [1 / min(query) for query in ranks]
    return bootstrap_ci(values)

        
def map_(ranks):
    """
    Mean Average Precision: As defined here page 7: https://datascience-intro.github.io/1MS041-2022/Files/AveragePrecision.pdf
    """
    values = [
        np.mean([
            (i + 1) / rank
            for i, rank in enumerate(sorted(query))
        ])
        for query in ranks
    ]
    return bootstrap_ci(values)


def map_k(ranks, k=5):
    values = []
    for query in ranks:
        ap_at_k = 0
        num_correct = 0
        for i, rank in enumerate(query):
            if rank <= k:
                num_correct += 1
                ap_at_k += num_correct / (i+1)
        values.append(ap_at_k)
    return bootstrap_ci(values)


def standard_metrics(ranks: Iterable[List[int]]) -> Dict[str, float]:
    """
    Calculate several metrics and their CIs based on the ranks provided
    
    Attributes:
        ranks - Iterable of results for individual queries. For each query a list of ranks is expected.
    """
        
    return {
        'pair_success_at_10': pair_success_at_k(ranks),
        'post_success_at_10': post_success_at_k(ranks),
        'precision_at_10': precision_at_k(ranks),
    }
    



In [9]:
"""
A library of functions that can be used to evaluate the results produced by _result generators_. Result generators are all the different methods
that can be used to retrieve fact-checks.

`process_result_generator` is the main API that should be used to create evaluation results.

Currently supported:
    bm25 - BM25
    dummy - Returns fixed order of fact-checks. Used for debugging.
    embedding - General method that supports different embedding models and then calculate cosine similarity between the vectors.
"""

def predicted_ranks(predicted_ids: np.array, desired_ids: np.array, default_rank: int = None):
    """
    Return sorted ranks of the `desired_ids` in the `predicted_ids` array.
    
    If `default_rank` is set, the final array will be padded with the value for all the ids that were not present in the `predicted_ids` array.
    """
    
    predicted_ranks = dict()
    
    for desired in desired_ids:
        
        try:
            rank = np.where(predicted_ids == desired)[0][0] + 1  # +1 so that the first item has rank 1, not 0
        except IndexError:
            rank = default_rank
       
        if rank is not None:
            predicted_ranks[desired] = rank
        
    return predicted_ranks


def process_result_generator(gen: Generator, default_rank: int = None, csv_path: str = None):
    """
    Take the results generated from `gen` and process them. By default, only calculate metrics, but dumping the results into a csv file is also supported
    via `csv_path` attribute. For `default_rank` see `predicted_ranks` function.
    """
    
    ranks = list()
    rows = list()
    
    for predicted_ids, desired_ids, post_id in gen:
        post_ranks = predicted_ranks(predicted_ids, desired_ids, default_rank)
        ranks.append(post_ranks.values())
        
        if csv_path:
            rows.append((post_id, post_ranks, predicted_ids[:50]))
            
    print(f'{sum(len(query) for query in ranks)} ranks produced.')
            
    if csv_path:
        pd.DataFrame(rows, columns=['post_id', 'desired_fact_check_ranks', 'predicted_fact_check_ids']).to_csv(csv_path, index=False)
      
    return standard_metrics(ranks)


def result_generator(func):
    """
    This is a decorator function that should be used on result generators. The generators return by default: `predicted_fact_check_ids` and `post_id`.
    Here, the results are enriched with the `desired_fact_check_ids` as indicated by `dataset.fact_check_post_mapping`
    """
    
    def wrapper(dataset, *args, **kwargs):
        
        desired_fact_check_ids = defaultdict(lambda: list())
        for fact_check_id, post_id in dataset.fact_check_post_mapping:
            desired_fact_check_ids[post_id].append(fact_check_id)
        
        for predicted_fact_check_ids, post_id in func(dataset, *args, **kwargs):
            yield predicted_fact_check_ids, desired_fact_check_ids[post_id], post_id
        
    return wrapper

In [10]:
class Vectorizer:
    """
    An abstract class for vectorizers. Vectorizers calculate vectors for texts and handle thir caching so that we do not have to calculate
    the vector for the same text more than once.
    
    Currently supported:
            `SentenceTransformerVectorizer` - Used for models using `sentence_transformers` library.
            `LaserVectorizer` - TBA
            `PytorchVectorizer` - Used for pytorch models (nn.Module).
            
    The main call for `Vectorizer` is `vectorize`. This will calculate the appropriate vectors and store them in the `dict` attribute. The class also
    supports `save` and `load`. `dir_path` is used as path to a folder where there is a `vocab.json` stored with the collection of texts and `vectors.py`
    with a torch tensor of vectors for the texts.
    """
    
    def __init__(self, dir_path: str):
        self.dict = {}
        
        self.dir_path = dir_path
        if dir_path:
            self.vectors_path = os.path.join(dir_path, 'vectors.pt')
            self.vocab_path = os.path.join(dir_path, 'vocab.json')
        
            try:
                self.load()
                print(f'Vector database with {len(self.dict)} records loaded')
            except FileNotFoundError:
                pass
                print(f'No previous database found')

        
    def vectorize(self, texts: List[str], save_if_missing: bool = False, normalize: bool = False) -> torch.tensor:
        """
        The main API point for the users. Try to find the vectors in the existing database. For the missing texts, the vectors will be calculated
        and saved in the `self.dict`. 
        
        Attributes:
            save_if_missing: bool  Should the vectors in `dict` be saved after new vectors are calculated? This makes sense for models that will
                be used more than once.
            normalize: bool  Should the vectors be normalized. Useful for cosine similarity calculations.
        """
        
        missing_texts = list(set(texts) - set(self.dict.keys()))
        
        if missing_texts:
            
            print(f'Calculating {len(missing_texts)} vectors.')
            missing_vectors = self._calculate_vectors(missing_texts)
            for text, vector in zip(missing_texts, missing_vectors):
                self.dict[text] = vector
            
            if save_if_missing:
                self.save()
            
        vectors = torch.vstack([
            self.dict[text]
            for text in texts
        ])
        
        if normalize:
            vectors = torch.nn.functional.normalize(vectors, p=2, dim=1)
            
        return vectors

    
    def _calculate_vectors(self, txts: List[str]) -> torch.tensor:
        """
        Abstract method to be implemented by subclasses.
        """
        raise NotImplementedError

            
    def load(self):
        """
        Load vocab and vectors from appropriate files
        """
        with open(self.vocab_path, 'r', encoding='utf8') as f:
            vocab = json.load(f)
        
        vectors = torch.load(self.vectors_path)
        
        assert len(vocab) == len(vectors)
        
        self.dict = {
            text: vector
            for text, vector in zip(vocab, vectors)
        }
        
        
    def save(self):
        """
        Save vocab and vectors to appropriate files
        """
        os.makedirs(self.dir_path, exist_ok=True)
            
        vocab = list(self.dict.keys())        
        with open(self.vocab_path, 'w', encoding='utf8') as f:
            json.dump(vocab, f)
            
        vectors = torch.vstack(list(self.dict.values()))
        torch.save(vectors, self.vectors_path)
            

In [11]:
def slice_text(text, window_type, window_size, window_stride=None) -> List[str]:
    """
    Split a `text` into parts using a sliding window. The windows slides either across characters or sentences, based on the value of `window_tyoe`.
    
    Attributes:
        text: str  Text that is to be splitted into windows.
        window_type: str  Either `sentence` or `character`. The basic unit of the windows.
        window_size: int  How many units are in a window.
        window_stride: int  How many units are skipped each time the window moves.
    """

    text = replace_whitespaces(text)
    
    if window_stride is None:
        window_stride = window_size
    
    if window_size < window_stride:
        print(f'Window size ({window_size}) is smaller than stride length ({window_stride}). This will result in missing chunks of text.')

        
    if window_type == 'sentence':       
        text = replace_stops(text)
        sentences = sent_tokenize(text)                
        return [
            ' '.join(sentences[i:i+window_size]) 
            for i in range(0, len(sentences), window_stride)
        ]
    
    elif window_type == 'character':
        return [
            text[i:i+window_size] 
            for i in range(0, len(text), window_stride)
        ]

    
def gen_sliding_window_delimiters(post_lengths: List[int], max_size: int) -> Generator[Tuple[int, int], None, None]:
    """
    Calculate where to split the sequence of `post_lenghts` so that the individual batches do not exceed `max_size`
    """
    range_length = start = cur_sum = 0
    
    for post_length in post_lengths:
        if (range_length + post_length) > max_size: # exceeds memory
            yield (start, start + range_length)
            start = cur_sum
            range_length = post_length
        else: # memory still avail in current split
            range_length += post_length
        cur_sum += post_length
        
    if range_length > 0:
        yield (start, start + range_length)


@result_generator
def embedding_results(
    dataset: Dataset,
    vectorizer_fact_check: Vectorizer,
    vectorizer_post: Vectorizer,
    sliding_window: bool = False,
    sliding_window_pooling: str = 'max',
    sliding_window_size: int = None,
    sliding_window_stride: int = None,
    sliding_window_type: str = None,
    post_split_size: int = 256,
    dtype: torch.dtype = torch.float32,
    device: str = 'cpu',
    save_if_missing: bool = False

):
    """
    Generate results using cosine similarity based on embeddings generated via vectorizers.
    
    Attributes:
        dataset: Dataset
        vectorizer_fact_check: Vectorizer  Vectorizer used to process fact-checks
        vectorizer_post: Vectorizer  Vectorizer used to process posts
        sliding_window: bool  Should sliding window be used or should texts be process without slicing.
        sliding_window_pooling: str  One of 'sum', 'mul', 'mean', 'min', 'max' as defined here: https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
        sliding_window_size, sliding_window_stride, sliding_window_type:  See `slice_text`
        post_split_size: int  Batch size for post embeddings for sim calculation
        dtype: torch.dtype  Data type in which calculate sim
        device: str  Device on which calculate sim
        save_if_missing: bool  Should the vectors in `dict` be saved after new vectors are calculated? This makes sense for models that will
                be used more than once.
    """
        
    print('Calculating embeddings for fact checks')
    fact_check_embeddings = vectorizer_fact_check.vectorize(
        dataset.id_to_fact_check.values(),
        save_if_missing=save_if_missing,
        normalize=True
    )
    fact_check_embeddings = fact_check_embeddings.transpose(0, 1)  # Rotate for matmul

    fact_check_embeddings = fact_check_embeddings.to(device=device, dtype=dtype)

        
    # We need to split the calculations because of memory limitations, sims matrix alone requires 200k x 25k x 4 = ~20GB RAM 
    # memory = 2**30 # assume 4gb free memory - 2**31 = 2gb for both `sims` and `sorted_ids`
    # post_split_size = memory // len(dataset.id_to_fact_check) // 4  # // 4 because of float32
    post_ids = iter(dataset.id_to_post.keys())
    
    if sliding_window:
        
        print('Splitting posts into windows.')
        windows = [
            slice_text(post, sliding_window_type, sliding_window_size, sliding_window_stride)
            for post in dataset.id_to_post.values()
        ]

        print('Calculating embeddings for the windows')
        post_embeddings = vectorizer_post.vectorize(
            list(itertools.chain(*windows)),
            save_if_missing=save_if_missing,
            normalize=True
        ) 
        
        # We need to split the matrix matmul so that all the windows from each post belong to the same batch.
        post_lengths = [len(post) for post in windows]
        segment_array = torch.tensor([
            i 
            for i, num_windows in enumerate(post_lengths) 
            for _ in range(num_windows)
        ])
        delimiters = list(gen_sliding_window_delimiters(post_lengths, post_split_size))
            
        print('Calculating similarity for data splits')
        
        for start_id, end_id in delimiters:

            sims = torch.mm(
                post_embeddings[start_id:end_id].to(device=device, dtype=dtype), 
                fact_check_embeddings
            )

            segments = segment_array[start_id:end_id]
            segments -= int(segments[0])

            sims = scatter(
                src=sims,
                index=segments,
                dim=0,
                reduce=sliding_window_pooling,
            )

            sorted_ids = torch.argsort(sims, descending=True, dim=1)

            fact_check_ids = {i: fc_id for i, fc_id in enumerate(dataset.id_to_fact_check.keys())}
            for row in sorted_ids:
                row = row.cpu().numpy()
                row = np.vectorize(fact_check_ids.__getitem__)(row)
                yield row, next(post_ids)

          
    else:
        
        print('Calculating embeddings for posts')
        post_embeddings = vectorizer_post.vectorize(
            dataset.id_to_post.values(),
            save_if_missing=save_if_missing,
            normalize=True
        )
        
        print('Calculating similarity for data splits')
        for start_id in range(0, len(dataset.id_to_post), post_split_size):
            end_id = start_id + post_split_size

            sims = torch.mm(
                post_embeddings[start_id:end_id].to(device=device, dtype=dtype), 
                fact_check_embeddings
            )

            # TODO: argsort does not duplicities into account, the results might not be deterministic
            sorted_ids = torch.argsort(sims, descending=True, dim=1)

            fact_check_ids = {i: fc_id for i, fc_id in enumerate(dataset.id_to_fact_check.keys())}
            for row in sorted_ids:
                row = row.cpu().numpy()
                row = np.vectorize(fact_check_ids.__getitem__)(row)
                yield row, next(post_ids)

In [12]:
class SentenceTransformerVectorizer(Vectorizer):
    """
    Vectorizer for `SentenceTransformer` models compatible with `sentence_transformers` library.
    """
    
    def __init__(
        self,
        dir_path: str,
        model_handle: str = None,
        model: SentenceTransformer = None,
        batch_size: int = 32
    ):
        """
        Attributes:
            dir_path: str  Path to cached vectors and vocab files.
            model_handle: str  Name of the model, either a HuggingFace repository handle or path to a local model.
            model: SentenceTransformer  A loaded model -- this option can be used during fine-tuning.
        """
        
        super().__init__(dir_path)
        
        if model_handle:
            self.model = SentenceTransformer(model_handle)
        else:
            self.model = model
            
        self.batch_size = batch_size
            
        assert self.model
        
        
    def _calculate_vectors(self, texts: List[str]) -> torch.tensor:
        
        return self.model.encode(
            texts,
            batch_size=self.batch_size,
            convert_to_numpy=False,
        )

In [13]:
class PytorchVectorizer(Vectorizer):
    """
    Vectorizer for `Pytorch` models.
    """
    
    def __init__(
        self,
        dir_path: str,
        model_handle: str = None,
        model: torch.nn.Module = None,
        tokenizer = None,
        batch_size: int = 32,
        dtype: torch.dtype = torch.float32,
        port_embeddings_to_cpu: bool = True
    ):
        """
        Attributes:
            dir_path: str  Path to cached vectors and vocab files.
            model_handle: str  Name of the model, either a HuggingFace repository handle or path to a local model.
            model: SentenceTransformer  A loaded model -- this option can be used during fine-tuning.
            tokenizer: AutoTokenizer  A tokenizer for the model.
            batch_size: int  Batch size for inference
            dtype: torch.dtype  Inference dtype
            port_embeddings_to_cpu: bool  Whether to move the embeddings to CPU after inference.
        """
        
        super().__init__(dir_path)
        
        if model_handle:
            self.model = torch.load(model_handle)
            self.model.eval()        
        else:
            self.model = model

        assert self.model

        self.device = next(self.model.parameters()).device.type
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.dtype = dtype
        self.port_embeddings_to_cpu = port_embeddings_to_cpu

        
    def _calculate_vectors(self, texts: List[str]) -> torch.tensor:

        @torch.autocast(device_type=self.device.split(':')[0], dtype=self.dtype)
        @torch.no_grad()
        def embedding_pipeline(text: List[str], tokenizer, model, device, max_length = 512):
            tokenized = tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device)
            embeddings = model(**tokenized)
            return embeddings.cpu() if self.port_embeddings_to_cpu else embeddings

        return torch.vstack(
                [
                    embedding_pipeline(
                        texts[i:i+self.batch_size], 
                        self.tokenizer, 
                        self.model, 
                        device=self.device, 
                        max_length=512
                    ) 
                    for i in range(0, len(texts), self.batch_size)
                ]
            )
            

In [14]:
"""
https://github.com/mosaicml/composer/blob/dev/composer/algorithms/ema/ema.py
Exponential Moving Average (EMA) is a model averaging technique that maintains an 
exponentially weighted moving average of the model parameters during training. 
The averaged parameters are used for model evaluation. EMA typically results 
in less noisy validation metrics over the course of training, and sometimes 
increased generalization.
"""

def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99):
    with torch.no_grad():
        model_params = itertools.chain(model.parameters(), model.buffers())
        ema_model_params = itertools.chain(ema_model.parameters(), ema_model.buffers())

        for ema_param, model_param in zip(ema_model_params, model_params):
            model_param = model_param.detach()
            ema_param.copy_(ema_param * smoothing + (1. - smoothing) * model_param)


In [15]:
"""
https://github.com/mosaicml/composer/blob/dev/composer/algorithms/sam/sam.py
Sharpness-Aware Minimization (SAM) is an optimization algorithm that minimizes 
both the loss and the sharpness of the loss. It finds parameters that lie in 
a neighborhood of low loss. The authors find that this improves model generalization
"""

class SAM(torch.optim.Optimizer):
    """Wraps an optimizer with sharpness-aware minimization (`Foret et al, 2020 <https://arxiv.org/abs/2010.01412>`_).
    See :class:`.SAM` for details.
    Implementation based on https://github.com/davda54/sam
    Args:
        base_optimizer (torch.optim.Optimizer) The optimizer to apply SAM to.
        rho (float, optional): The SAM neighborhood size. Must be greater than 0. Default: ``0.05``.
        epsilon (float, optional): A small value added to the gradient norm for numerical stability. Default: ``1.0e-12``.
        interval (int, optional): SAM will run once per ``interval`` steps. A value of 1 will
            cause SAM to run every step. Steps on which SAM runs take
            roughly twice as much time to complete. Default: ``1``.
    """

    def __init__(
        self,
        base_optimizer: torch.optim.Optimizer,
        rho: float = 0.05,
        epsilon: float = 1.0e-12,
        **kwargs
    ):
        if rho < 0:
            raise ValueError(f'Invalid rho, should be non-negative: {rho}')
        self.base_optimizer = base_optimizer
        defaults = {'rho': rho, 'epsilon': epsilon, **kwargs}
        super(SAM, self).__init__(self.base_optimizer.param_groups, defaults)

    @torch.no_grad()
    def first_step(self):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + group['epsilon'])
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def second_step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or 'e_w' not in self.state[p]:
                    continue
                p.sub_(self.state[p]['e_w'])  # get back to "w" from "w + e(w)"
        self.base_optimizer.step()  # do the actual "sharpness-aware" update

    @torch.no_grad()
    def step(self, step_type: str):
        """ Adjusted to PyTorch mixed precision training framework
        """
        if step_type == 'first':
            self.first_step()
        elif step_type == 'second':
            self.second_step()
        elif step_type == 'skip':
            self.base_optimizer.step()


    def _grad_norm(self):
        norm = torch.norm(torch.stack(
            [p.grad.norm(p=2) for group in self.param_groups for p in group['params'] if p.grad is not None]),
                          p='fro')
        return norm

In [16]:
"""
https://github.com/mosaicml/composer/blob/dev/composer/algorithms/layer_freezing/layer_freezing.py
Layer Freezing gradually makes early modules untrainable ("freezing" them), saving the cost of 
backpropagating to and updating frozen modules. The hypothesis behind Layer Freezing 
is that early layers may learn their features sooner than later layers, meaning they 
do not need to be updated later in training. Especially for fine-tuning, which is our case.
"""

def freeze_layers(
    model: torch.nn.Module,
    optimizers: Union[Optimizer, Sequence[Optimizer]],
    current_duration: float,
    freeze_start: float = 0.5,
    freeze_level: float = 1.0,
) -> Tuple[int, float]:
    """Progressively freeze the layers of the network in-place
    during training, starting with the earlier layers.
    Example:
         .. testcode::
            from composer.algorithms.layer_freezing import freeze_layers
            freeze_depth, feeze_level = freeze_layers(
                                            model=model,
                                            optimizers=optimizer,
                                            current_duration=0.5,
                                            freeze_start=0.0,
                                            freeze_level=1.0
                                        )
    Args:
        model (torch.nn.Module): The model being trained.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer]):
            The optimizers used during training.
        current_duration (float): The fraction, in ``[0, 1)`` of the training process complete.
        freeze_start (float, optional): The fraction of the training process in ``[0, 1)`` to run
            before freezing begins. Default: ``0.5``.
        freeze_level (float, optional): The maximum fraction of layers on ``[0, 1)`` to freeze.
            Default: ``1.0``.
    Return:
        (int, float): The number of layers frozen, and the percentage of the total model frozen.
    """
    # Flatten out the layers
    flat_children = []
    _get_layers(model, flat_children)
    # Determine how many layers to freeze
    freeze_percentage = _freeze_schedule(current_duration=current_duration,
                                         freeze_start=freeze_start,
                                         freeze_level=freeze_level)
    freeze_depth = int(freeze_percentage * len(flat_children[0:-1]))

    # Freeze the parameters in the chosen layers
    for i, child in enumerate(flat_children[0:-1]):
        if i < freeze_depth:
            for p in child.parameters():
                _remove_param_from_optimizers(p, optimizers)
                # Do not compute gradients for this param.
                p.requires_grad = False

    return freeze_depth, freeze_percentage


def _freeze_schedule(current_duration: float, freeze_start: float, freeze_level: float) -> float:
    """Implements a linear schedule for freezing.
    The schedule is linear and begins with no freezing and linearly
    increases the fraction of layers frozen, reaching the fraction specified by ``freeze_level`` at the end of training.
    The start of freezing is given as a fraction of the total training duration and is set with ``freeze_start``.
    Args:
        current_duration (float): The elapsed training duration.
        freeze_start (float): The fraction of training to run before freezing begins.
        freeze_level (float): The maximum fraction of levels to freeze.
    """
    # No freezing if the current epoch is less than this
    if current_duration <= freeze_start:
        return 0.0
    # `Calculate the total time for freezing to occur
    total_freezing_time = 1.0 - freeze_start
    # Calculate the amount of freezing time that has elapsed
    freezing_time_elapsed = current_duration - freeze_start
    # Calculate the fraction of the freezing time elapsed.
    freezing_time_elapsed_frac = freezing_time_elapsed / total_freezing_time
    # Scale this fraction by the amount of freezing to do.
    return freeze_level * freezing_time_elapsed_frac


def _get_layers(module: torch.nn.Module, flat_children: List[torch.nn.Module]):
    """Helper function to get all submodules.
    Does a depth first search to flatten out modules which
    contain parameters.
    Args:
        module (torch.nn.Module): Current module to search.
        flat_children (List[torch.nn.Module]): List containing modules.
    """
    # Check if given module has no children and parameters.
    if (len(list(module.children())) == 0 and len(list(module.parameters())) > 0):
        flat_children.append(module)
    else:
        # Otherwise, continue the search over its children.
        for child in module.children():
            _get_layers(child, flat_children)


def _remove_param_from_optimizers(p: torch.nn.Parameter, optimizers: Union[Optimizer, Sequence[Optimizer]]):
    """Helper function to freeze the training of a parameter.
    To freeze a parameter, it must be removed from the optimizer,
    otherwise momentum and weight decay may still be applied.
    Args:
        p (torch.nn.Parameter): The parameter being frozen.
        optimizers (torch.optim.Optimizer | Sequence[torch.optim.Optimizer]): The optimizers used during training.
    """
    # Search over params in the optimizers to find and remove the
    # given param. Necessary due to the way params are stored.
    for optimizer in ensure_tuple(optimizers):
        for group in optimizer.param_groups:
            group['params'] = list(filter(lambda x: id(x) != id(p), group['params']))


def ensure_tuple(x):
    """Converts ``x`` into a tuple.
    * If ``x`` is ``None``, then ``tuple()`` is returned.
    * If ``x`` is a tuple, then ``x`` is returned as-is.
    * If ``x`` is a list, then ``tuple(x)`` is returned.
    * If ``x`` is a dict, then ``tuple(v for v in x.values())`` is returned.
    Otherwise, a single element tuple of ``(x,)`` is returned.
    Args:
        x (Any): The input to convert into a tuple.
    Returns:
        tuple: A tuple of ``x``.
    """
    if x is None:
        return ()
    if isinstance(x, (str, bytes, bytearray)):
        return (x,)
    if isinstance(x, collections.abc.Sequence):
        return tuple(x)
    if isinstance(x, dict):
        return tuple(x.values())
    return (x,)

In [17]:
class MNRloss(nn.Module):
    def __init__(self, label_smoothing=0):
        super().__init__()
        self.loss_f = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    def forward(self, sentence_embedding_A: torch.Tensor, sentence_embedding_B: torch.Tensor):
        # Compute similarity matrix
        scores = torch.mm(sentence_embedding_A, sentence_embedding_B.transpose(0, 1))
        # Compute labels
        labels = torch.arange(len(scores), dtype=torch.long, device=scores.device)
        return self.loss_f(scores, labels)

In [18]:
class Model(nn.Module):
    def __init__(self, model_name, pooling='default'):
        super(Model, self).__init__()
        self.pooling = pooling
        
        config = AutoConfig.from_pretrained(model_name)
        
        if isinstance(config, T5Config):         
            from transformers import T5EncoderModel
            T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
            self.model = T5EncoderModel.from_pretrained(model_name, config=config)
        else:
            self.model = AutoModel.from_pretrained(model_name)
            

    def forward(self, input_ids, attention_mask):
        model_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        if self.pooling == 'default':
            return model_output[1]
        elif self.pooling == 'mean':       
            return self.mean_pooling(model_output[0], attention_mask)

        
    @staticmethod
    def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
    
def get_model_tokenizer(model_name: str, **kwargs):
    return Model(model_name, **kwargs), AutoTokenizer.from_pretrained(model_name)

In [19]:
def get_opimizer_scheduler(model, cfg: Config):
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    
    optimizer = _get_optimizer(optimizer_grouped_parameters, cfg)
    scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer,
        num_warmup_steps=cfg.train.warmup_steps,
        num_training_steps=cfg.train.num_steps,
        num_cycles=1.0    
    )
    
    if cfg.optimizer.sam:
        return SAM(optimizer, rho=cfg.optimizer.sam_rho)
    else:
        return optimizer, scheduler


def _get_optimizer(parameters, cfg: Config):
    if cfg.optimizer.name == 'adamw':
        return torch.optim.AdamW(
            parameters, 
            lr=cfg.optimizer.lr,
            weight_decay=cfg.optimizer.weight_decay
        )
    else:
        raise Exception(f'Wrong optimizer name {cfg.optimizer.name}')

In [20]:
BATCH = Tuple[Dict[str, torch.Tensor]]


def safe_mean(x, round_to=4):
    try:
        return round(mean(x), round_to)
    except StatisticsError:
        return None


def set_seed(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def collate_fn(batch, tokenizer, cfg: Config):
    p_text, fc_text = zip(*batch)
    return (
        tokenizer(list(p_text), padding=True, truncation=True, max_length=cfg.train.p_max_seq_length, return_tensors='pt'),
        tokenizer(list(fc_text), padding=True, truncation=True, max_length=cfg.train.fc_max_seq_length, return_tensors='pt')
    )


def train_step(step, model, batch: BATCH, loss_fn, optimizer, scheduler, scaler, ctx_autocast, cfg: Config):
    model.train()
    posts_encoded, fact_checks_encoded = batch
    posts_encoded, fact_checks_encoded = posts_encoded.to(cfg.train.device), fact_checks_encoded.to(cfg.train.device)
    
    def fw_bw():
        optimizer.zero_grad(set_to_none=True)
        with ctx_autocast:
            loss = loss_fn(model(**fact_checks_encoded), model(**posts_encoded))
        scaler.scale(loss).backward()
        if cfg.optimizer.clip_value is not None:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.optimizer.clip_value)
        return loss  
      
    if isinstance(optimizer, SAM):
        step_type = ('first', 'second') if (step % cfg.optimizer.sam_step) == 0 else ('skip',)
        for s in step_type:
            loss = fw_bw()
            scaler.step(optimizer, step_type=s)
            scaler.update()
    else:
        loss = fw_bw()
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()
    return loss.item()


def evaluate_datasets(datasets, model, tokenizer, run_path, save_vectors):
    model.eval()
    results = {}
        
    vectorizer_path = run_path if save_vectors else None
    vct = PytorchVectorizer(dir_path=vectorizer_path, model=model, tokenizer=tokenizer, batch_size=256, port_embeddings_to_cpu=True)
    
    for dataset_name, dataset in datasets.items():
        result_gen = embedding_results(dataset, vct, vct, post_split_size=512, device='cuda')
        metrics = process_result_generator(result_gen, csv_path=os.path.join(run_path, dataset_name + '.csv'))
        for value_name, values in metrics.items():
            v_mean, v_ci_lower, v_ci_upper = values
            results[f'{dataset_name}_{value_name}_mean'] = v_mean
            results[f'{dataset_name}_{value_name}_lower'] = v_ci_lower 
            results[f'{dataset_name}_{value_name}_upper'] = v_ci_upper  
    return results


def train(cfg, train_dataset, dev_datasets, test_datasets = None):

    set_seed(cfg.seed)

    if cfg.train.output_path is not None:
        run_path = os.path.join(cfg.train.output_path, cfg.timestamp)
        os.makedirs(run_path)
        # save config
        with open(os.path.join(run_path, 'config.yaml') ,'w') as f:
            yaml.dump(cfg.to_dict(), f, default_flow_style=False)

    print('Get model')
    model, tokenizer = get_model_tokenizer(cfg.model.name, pooling=cfg.model.pooling)
    model = model.to(cfg.train.device)

    # Only if EMA is used
    if cfg.train.ema:
        ema_model = copy.deepcopy(model)

    print('Prepare dataloader')
    partial_collate_fn = partial(collate_fn, tokenizer=tokenizer, cfg=cfg)
    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=cfg.train.batch_size, 
        shuffle=True, 
        collate_fn=partial_collate_fn, 
        pin_memory=True,
        drop_last=True
    )
    cfg.train.num_steps = len(train_dataloader) * cfg.train.num_epochs
    # If evaluation is done after each epoch
    if cfg.train.evaluation_steps == -1:
        cfg.train.evaluation_steps = len(train_dataloader)

    print('Prepare optimizer and loss')
    optimizer, scheduler = get_opimizer_scheduler(model, cfg)
    loss_fn = MNRloss(cfg.train.label_smoothing)
    
    print('Train')
    step = 0
    training_loss = []

    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[cfg.train.dtype]
    ctx_autocast = torch.autocast(device_type=cfg.train.device.split(':')[0], dtype=ptdtype)
    scaler = torch.cuda.amp.GradScaler(enabled=(cfg.train.dtype == 'float16'))

    for epoch in range(cfg.train.num_epochs):
        print(f"EPOCH: {epoch}")
        for batch in train_dataloader:
            step += 1

            train_loss_value = train_step(step, model, batch, loss_fn, optimizer, scheduler, scaler, ctx_autocast, cfg)
            training_loss.append(train_loss_value)
            
            # Loss reporting
            if (step % cfg.train.metrics_window_size) == 0: 
                agg_loss = mean(training_loss)
                print(
                    f'step {step}/{cfg.train.num_steps}: train/loss-mean{cfg.train.metrics_window_size} {agg_loss:.4f}'
                )
                print({
                    f'train/loss-last-step{cfg.train.metrics_window_size}': train_loss_value, 
                    'lr': optimizer.param_groups[0]['lr']
                    }
                )
                training_loss = []

            # Only if EMA is used
            if cfg.train.ema:
                compute_ema(model, ema_model, smoothing=0.99)
            
            # Evaluation reporting
            if (step % cfg.train.evaluation_steps == 0) or (cfg.train.num_steps == step):
                step_path = os.path.join(run_path, str(step))
                os.makedirs(step_path)

                eval_model = ema_model if cfg.train.ema else model
                # results = evaluate_datasets({**dev_datasets, **test_datasets}, eval_model, tokenizer, step_path, save_vectors=cfg.train.save_vectors)
                results = evaluate_datasets({**dev_datasets}, eval_model, tokenizer, step_path, save_vectors=cfg.train.save_vectors)

                mean_dev_ps = safe_mean([results[f'{dataset_name}_pair_success_at_10_mean'] for dataset_name in dev_datasets.keys()])
                # mean_train_ps = safe_mean([results[f'{dataset_name}_pair_success_at_10_mean'] for dataset_name in test_datasets.keys()])
                print(
                    f'step {step}/{cfg.train.num_steps}: dev/mean-ps@10 {mean_dev_ps}'
                )

                if cfg.train.save_models:
                    torch.save(eval_model.state_dict(), os.path.join(step_path, 'model.pt'))

        # Applying layer freezing at the end of the epoch
        if cfg.train.freezing:
            _, feeze_level = freeze_layers(
                model=model, optimizers=optimizer, 
                current_duration=epoch / cfg.train.num_epochs, 
                freeze_start=0.0, freeze_level=1.0
            )
            print({'step': step, 'feeze_level': feeze_level})

In [21]:
train_dataset = OurDataset(split='train', fold=1).load()
dev_dataset = {'dev_eng': OurDataset(split='test', fold=1).load()}

cfg = Config()
cfg.train.num_epochs = 5
cfg.model.name = 'intfloat/multilingual-e5-large-instruct'
cfg.model.pooling = 'mean'
train(cfg, train_dataset, dev_dataset)

Loading fact-checks.
153743 loaded.
Loading posts.
24431 loaded.
Loading fact-check-post mapping.
25743 loaded.
Filtering by split: 17590 posts remaining and sampled 44223 fact checks
Filtering fact-checks by language: 44223 posts remaining.
Filtering posts by language: 17590 posts remaining.
Mappings remaining: 20594.
Filtering posts.
Posts remaining: 17590
Filtering by split: 4398 posts remaining and sampled 44223 fact checks
Filtering fact-checks by language: 44223 posts remaining.
Filtering posts by language: 4398 posts remaining.
Mappings remaining: 5149.
Filtering posts.
Posts remaining: 4398
Get model


config.json:   0%|          | 0.00/690 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

Prepare dataloader
Prepare optimizer and loss
Train
EPOCH: 0


  scaler = torch.cuda.amp.GradScaler(enabled=(cfg.train.dtype == 'float16'))


step 128/12870: train/loss-mean128 3.0170
{'train/loss-last-step128': 9.685717259344528e-07, 'lr': 1.6000000000000001e-06}
step 256/12870: train/loss-mean128 0.5507
{'train/loss-last-step128': 0.0, 'lr': 3.2000000000000003e-06}
step 384/12870: train/loss-mean128 0.2481
{'train/loss-last-step128': 0.0, 'lr': 4.800000000000001e-06}
step 512/12870: train/loss-mean128 0.2319
{'train/loss-last-step128': 2.9280853271484375, 'lr': 4.999004860230925e-06}
step 640/12870: train/loss-mean128 0.1453
{'train/loss-last-step128': 1.370440125465393, 'lr': 4.995431569514878e-06}
step 768/12870: train/loss-mean128 0.1941
{'train/loss-last-step128': 0.0, 'lr': 4.989263533058938e-06}
step 896/12870: train/loss-mean128 0.1804
{'train/loss-last-step128': 1.3749001026153564, 'lr': 4.980507164377509e-06}
step 1024/12870: train/loss-mean128 0.1583
{'train/loss-last-step128': 0.06668083369731903, 'lr': 4.969171568328689e-06}
step 1152/12870: train/loss-mean128 0.1198
{'train/loss-last-step128': 1.49011603056692