In [4]:
!pip install -q pyterrier
!pip install -q unidecode

In [5]:
import torch
from typing import Optional, Tuple

def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """
    Expand `src` so that it has the same size as `other`.

    A negative `dim` is first resolved against `other.dim()`. A one-dimensional
    `src` gets `dim` leading singleton axes, so its values line up with axis
    `dim` of `other`. Trailing singleton axes are then appended until the ranks
    match, and the result is expanded (no copy) to `other.size()`.
    """
    if dim < 0:
        dim += other.dim()

    if src.dim() == 1:
        leading_axes = dim
        while leading_axes > 0:
            src = src.unsqueeze(0)
            leading_axes -= 1

    while src.dim() < other.dim():
        src = src.unsqueeze(-1)

    return src.expand(other.size())


def scatter_sum(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """
    Sum the values of `src` into the slots selected by `index` along `dim`.

    `index` is first broadcast to the size of `src`. When `out` is supplied the
    values are accumulated into it in place and it is returned. Otherwise a
    zero tensor is allocated whose size along `dim` is `dim_size` if given,
    `0` for an empty index, and `index.max() + 1` in every other case.
    """
    index = broadcast(index, src, dim)

    if out is not None:
        return out.scatter_add_(dim, index, src)

    shape = list(src.size())
    if dim_size is not None:
        shape[dim] = dim_size
    elif index.numel() == 0:
        shape[dim] = 0
    else:
        shape[dim] = int(index.max()) + 1

    result = torch.zeros(shape, dtype=src.dtype, device=src.device)
    return result.scatter_add_(dim, index, src)


def scatter_add(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Alias of `scatter_sum`; all arguments are forwarded unchanged."""
    summed = scatter_sum(src, index, dim, out, dim_size)
    return summed


def scatter_mul(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """
    Multiplicative scatter: forwards all arguments to the custom operator
    `torch.ops.torch_scatter.scatter_mul` and returns its result.

    NOTE(review): that operator is presumably only registered once the
    `torch_scatter` extension has been loaded; nothing visible in this notebook
    imports it, so this path may fail at call time -- confirm.
    """
    return torch.ops.torch_scatter.scatter_mul(src, index, dim, out, dim_size)


def scatter_mean(src: torch.Tensor,
                 index: torch.Tensor,
                 dim: int = -1,
                 out: Optional[torch.Tensor] = None,
                 dim_size: Optional[int] = None) -> torch.Tensor:
    """
    Average the values of `src` that share the same `index` along `dim`.

    The sums are produced by `scatter_sum` (accumulating into `out` when it is
    given), then divided in place by the number of contributions per slot.
    Slots that received nothing keep a divisor of 1. Floating point outputs use
    true division, every other dtype uses floor division.
    """
    totals = scatter_sum(src, index, dim, out, dim_size)
    n_slots = totals.size(dim)

    # Axis of `index` along which to count; clamp it when `index` has fewer
    # dimensions than `src` (it is broadcast later).
    count_dim = dim + src.dim() if dim < 0 else dim
    if index.dim() <= count_dim:
        count_dim = index.dim() - 1

    unit = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    counts = scatter_sum(unit, index, count_dim, None, n_slots)
    counts[counts < 1] = 1  # avoid dividing empty slots by zero
    counts = broadcast(counts, totals, dim)

    if totals.is_floating_point():
        totals.true_divide_(counts)
    else:
        totals.div_(counts, rounding_mode='floor')
    return totals


def scatter_min(
        src: torch.Tensor,
        index: torch.Tensor,
        dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Minimum scatter: forwards all arguments to the custom operator
    `torch.ops.torch_scatter.scatter_min` and returns its two-tensor result.

    `scatter` in this file uses element `[0]` as the reduced values; the second
    element is presumably the position of each minimum -- confirm.

    NOTE(review): the operator is presumably only registered once the
    `torch_scatter` extension has been loaded; nothing visible in this notebook
    imports it -- confirm.
    """
    return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size)


def scatter_max(
        src: torch.Tensor,
        index: torch.Tensor,
        dim: int = -1,
        out: Optional[torch.Tensor] = None,
        dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Maximum scatter: forwards all arguments to the custom operator
    `torch.ops.torch_scatter.scatter_max` and returns its two-tensor result.

    `scatter` in this file uses element `[0]` as the reduced values; the second
    element is presumably the position of each maximum -- confirm.

    NOTE(review): the operator is presumably only registered once the
    `torch_scatter` extension has been loaded; nothing visible in this notebook
    imports it -- confirm.
    """
    return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size)


def scatter(src: torch.Tensor,
            index: torch.Tensor,
            dim: int = -1,
            out: Optional[torch.Tensor] = None,
            dim_size: Optional[int] = None,
            reduce: str = "sum") -> torch.Tensor:
    r"""
    |

    .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/
            master/docs/source/_figures/add.svg?sanitize=true
        :align: center
        :width: 400px

    |

    Reduces all values from the :attr:`src` tensor into :attr:`out` at the
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :attr:`dim` and by the
    corresponding value in :attr:`index` for dimension :attr:`dim`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional
    tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
    and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional
    tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`.
    Moreover, the values of :attr:`index` must be between :math:`0` and
    :math:`y - 1`, although no specific ordering of indices is required.
    The :attr:`index` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    .. note::

        This operation is implemented via atomic operations on the GPU and is
        therefore **non-deterministic** since the order of parallel operations
        to the same value is undetermined.
        For floating-point variables, this results in a source of variance in
        the result.

    :param src: The source tensor.
    :param index: The indices of elements to scatter.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor.
    :param dim_size: If :attr:`out` is not given, automatically create output
        with size :attr:`dim_size` at dimension :attr:`dim`.
        If :attr:`dim_size` is not given, a minimal sized output tensor
        according to :obj:`index.max() + 1` is returned.
    :param reduce: The reduce operation (:obj:`"sum"` / :obj:`"add"`,
        :obj:`"mul"`, :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`).
        (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`
    :raises ValueError: If :attr:`reduce` is not one of the supported names.

    .. code-block:: python

        from torch_scatter import scatter

        src = torch.randn(10, 6, 64)
        index = torch.tensor([0, 1, 0, 1, 2, 1])

        # Broadcasting in the first and last dim.
        out = scatter(src, index, dim=1, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return scatter_sum(src, index, dim, out, dim_size)
    elif reduce == 'mul':
        return scatter_mul(src, index, dim, out, dim_size)
    elif reduce == 'mean':
        return scatter_mean(src, index, dim, out, dim_size)
    elif reduce == 'min':
        # The min/max helpers return a pair; only the reduced values are exposed here.
        return scatter_min(src, index, dim, out, dim_size)[0]
    elif reduce == 'max':
        return scatter_max(src, index, dim, out, dim_size)[0]
    else:
        # Same exception type as before, but now says what was wrong.
        raise ValueError(
            f"Unsupported reduce operation {reduce!r}; expected one of "
            "'sum', 'add', 'mul', 'mean', 'min' or 'max'."
        )


In [6]:
from dataclasses import dataclass, asdict, field
import os
from typing import Any, Dict, Optional
import yaml
import re
import string
import datetime
import ast
from os.path import join as join_path
from collections import *
import argparse
import copy
import glob
from functools import partial
from statistics import mean, median, StatisticsError
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import get_cosine_with_hard_restarts_schedule_with_warmup
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, T5Config
import collections
from torch.optim import Optimizer
import itertools
from nltk.tokenize import sent_tokenize
import numpy as np
import json
from typing import *
import math
import random
import statistics
from scipy.stats import norm
from argparse import Namespace
import logging
import os
import shutil
import string
import pandas as pd
import pyterrier as pt
from unidecode import unidecode

@dataclass
class Model:
    """Settings that select the encoder model and how its outputs are pooled."""
    name: str = 'sentence-transformers/all-mpnet-base-v2'  # model identifier (presumably a Hugging Face hub name -- confirm)
    pooling: str = 'default'  # 'mean', 'default'


@dataclass
class Train:
    """
    Training-run settings (device, precision, batch sizes, schedule, checkpointing).

    How each value is consumed is not visible in this cell; the per-field notes
    below restate the original inline comments.
    """
    device: str = 'cuda' 
    dtype: str = 'float16'  # float32, float16, bfloat16
    batch_size: int = 16
    num_epochs: int = 20
    accumulation_steps: int = 4  # presumably gradient-accumulation steps -- confirm against the training loop
    metrics_window_size: int = 128
    warmup_steps: int = 400
    evaluation_steps: int = -1  # If -1, after each epoch
    label_smoothing: float = 0.  # [0, 1]
    ema: bool = False  # Exponential moving average
    freezing: bool = False  # Gradually makes early modules untrainable
    # NOTE(review): `os.path.join` with a single argument returns that argument unchanged.
    output_path: Optional[str] = os.path.join('/kaggle/working/')  # If None, no model checkpointing every evaluation_steps
    save_vectors: bool = False  # Whether to save vectors for evaluation
    save_models: bool = True # Whether to save models
    p_max_seq_length: int = 1024  # presumably the max token length for posts -- confirm
    fc_max_seq_length: int = 1024  # presumably the max token length for fact-checks -- confirm

import os
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): presumably this has to be set before CUDA is first initialised
# to take effect -- confirm the cell ordering guarantees that.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# NOTE(review): this class rebinds the name `Optimizer`, which the import cell
# above also binds via `from torch.optim import Optimizer`. After this cell the
# name refers to this config dataclass, not the torch base class.
@dataclass
class Optimizer:
    """Optimizer hyper-parameters (learning rates, regularisation, clipping, SAM)."""
    name: str = 'adamw' 
    lr: float = 5e-5
    lr_transformer: float = 5e-5  # presumably the learning rate for the transformer backbone -- confirm
    lr_custom: float = 1e-4  # presumably the learning rate for non-backbone modules -- confirm
    weight_decay: float = 0.005 
    clip_value: float = 1.0  # presumably the gradient clipping threshold -- confirm
    sam: bool = False  # presumably enables sharpness-aware minimisation -- confirm
    sam_step: float = 1. 
    sam_rho: float = 0.05 

    
@dataclass
class Config:
    """
    Top-level configuration bundling the `Model`, `Train` and `Optimizer`
    sections plus the random seed.

    Every instance owns fresh section objects, so overrides applied to one
    `Config` never leak into another. The sections are declared with
    `default_factory` because dataclass instances are unhashable and Python
    3.11+ refuses them as plain field defaults.

    Attributes:
        model, train, optimizer: Section objects, created per instance.
        seed: Random seed.
        timestamp: Creation time formatted as `%d-%m-%Y-%H-%M-%S` (not a
            dataclass field, so it is not part of `to_dict()`).
    """
    model: Model = field(default_factory=Model)
    train: Train = field(default_factory=Train)
    optimizer: Optimizer = field(default_factory=Optimizer)

    seed: int = 3407

    def __init__(self, path: Optional[str] = None):
        """
        Build a config with default values and optionally override them from a
        YAML file.

        Args:
            path: Optional path to a YAML file whose top-level keys mirror the
                attribute names of this class.
        """
        # A hand-written __init__ replaces the generated one, so the default
        # factories have to be invoked explicitly.
        self.model = Model()
        self.train = Train()
        self.optimizer = Optimizer()
        self.timestamp = datetime.datetime.now().strftime('%d-%m-%Y-%H-%M-%S')
        if path is not None:
            with open(path, 'r') as f:
                loaded = yaml.load(f, Loader=yaml.SafeLoader)
            # An empty YAML document parses to None: treat it as "no overrides".
            self.init_class(loaded or {})

    def to_dict(self):
        """Return the dataclass fields (sections and seed) as nested dicts."""
        return asdict(self)

    def init_class(self, d):
        """
        Apply the overrides in `d` to this config.

        For every public attribute name that also appears in `d`: a dict value
        updates the matching section attribute by attribute, any other value
        replaces the attribute itself.
        """
        for name in dir(self):
            if name.startswith('_') or name.endswith('_') or name not in d:
                continue
            value = d[name]
            if isinstance(value, dict):
                section = getattr(self, name)
                for k, v in value.items():
                    setattr(section, k, v)
            else:
                setattr(self, name, value)

In [7]:
"""
Library of various cleaning-related functions, regular expressions and variables.
"""


# All ASCII letters, lower case first.
simple_latin = string.ascii_letters
# Characters counted as "dirty": ASCII digits and punctuation.
dirty_chars = string.digits + string.punctuation


def is_clean_text(text: str) -> bool:
    """
    Cheap quality filter for a piece of text.

    Returns False when the text is shorter than 25 characters or when more than
    half of its characters are in `dirty_chars`; True otherwise.
    """
    if len(text) < 25:  # Short text
        return False
    dirty_count = sum(char in dirty_chars for char in text)
    return not (0.5 < dirty_count / len(text))  # More than 50% dirty chars


# Source: https://gist.github.com/dperini/729294
# Matches http(s)/ftp/www URLs. Private and reserved IPv4 hosts are excluded
# by the negative look-aheads; the host may otherwise be a public IPv4
# address, a (possibly internationalised) domain name, or `localhost`.
#
# The Unicode range in the host-name classes is written as `\u00a1-\uffff`
# (single backslash). Inside a raw string `\\u00a1` would be a literal
# backslash followed by ASCII characters, so internationalised host names
# would never match.
url_regex = re.compile(
    r'(?:^|(?<![\w\/\.]))'
    # scheme or www prefix
    r'(?:(?:https?:\/\/|ftp:\/\/|www\d{0,3}\.))'
    # optional user:password@
    r'(?:\S+(?::\S*)?@)?'
    r'(?:'
    # IPv4 host, excluding private / link-local ranges
    r'(?!(?:10|127)(?:\.\d{1,3}){3})'
    r'(?!(?:169\.254|192\.168)(?:\.\d{1,3}){2})'
    r'(?!172\.(?:1[6-9]|2\d|3[0-1])(?:\.\d{1,3}){2})'
    r'(?:[1-9]\d?|1\d\d|2[01]\d|22[0-3])'
    r'(?:\.(?:1?\d{1,2}|2[0-4]\d|25[0-5])){2}'
    r'(?:\.(?:[1-9]\d?|1\d\d|2[0-4]\d|25[0-4]))'
    r'|'
    # host name, sub-domains and TLD
    r'(?:(?:[a-z\u00a1-\uffff0-9]-?)*[a-z\u00a1-\uffff0-9]+)'
    r'(?:\.(?:[a-z\u00a1-\uffff0-9]-?)*[a-z\u00a1-\uffff0-9]+)*'
    r'(?:\.(?:[a-z\u00a1-\uffff]{2,}))'
    r'|'
    r'(?:(localhost))'
    r')'
    # optional port
    r'(?::\d{2,5})?'
    # optional path, stopping at closing brackets or whitespace
    r'(?:\/[^\)\]\}\s]*)?',
    flags=re.IGNORECASE,
)


def remove_urls(text: str) -> str:
    """Return `text` with every match of `url_regex` deleted."""
    return url_regex.sub('', text)


# Source: https://gist.github.com/Nikitha2309/15337f4f593c4a21fb0965804755c41d
# Character class of emoji and pictographic symbols.
#
# The original pattern contained the range U+24C2-U+1F251. That single range
# covers the rest of the Basic Multilingual Plane, so CJK ideographs, kana,
# Hangul, Arabic presentation forms and more were deleted as "emoji". It is
# replaced here by the symbol blocks it was meant to cover; the individual BMP
# emoji that used to be caught only by it (U+303D, U+3297, U+3299) are listed
# explicitly.
emoji_regex = re.compile(
    '['
    '\U0001F600-\U0001F64F'  # emoticons
    '\U0001F300-\U0001F5FF'  # symbols & pictographs
    '\U0001F680-\U0001F6FF'  # transport & map symbols
    '\U0001F1E0-\U0001F1FF'  # flags (iOS)
    '\U0001F200-\U0001F251'  # enclosed ideographic supplement
    '\U0001f926-\U0001f937'
    '\U00010000-\U0010ffff'  # every remaining supplementary-plane code point
    '\u24C2-\u24FF'          # enclosed alphanumerics from U+24C2 onwards
    '\u2500-\u2BEF'          # box drawing, shapes, misc symbols, dingbats, arrows
    '\u2702-\u27B0'          # dingbats
    '\u2640-\u2642'
    '\u2600-\u2B55'
    '\u200d'                 # zero-width joiner
    '\u23cf'
    '\u23e9'
    '\u231a'
    '\ufe0f'                 # variation selector-16
    '\u3030'
    '\u303d'
    '\u3297'
    '\u3299'
    ']+'
)


def remove_emojis(text: str) -> str:
    """Return `text` with every run of characters matched by `emoji_regex` deleted."""
    return emoji_regex.sub('', text)


# Runs of characters that are treated as sentence terminators.
# The pipe is written literally: inside a character class it needs no escape,
# and `'\|'` in a non-raw string is an invalid escape sequence that newer
# Python versions warn about.
sentence_stop_regex = re.compile(
    '['
    '\u002e'  # full stop
    '\u2026'  # ellipsis
    '\u061F'  # arabic question mark
    '\u06D4'  # arabic full stop
    '\u2022'  # bullet point
    '\u3002'  # chinese period
    '\u25CB'  # white circle
    '|'       # pipe
    ']+'
)


def replace_stops(text: str) -> str:
    """
    Replace every run of sentence-ending characters (see `sentence_stop_regex`)
    with a single full stop. Used for sentence segmentation with sliding windows.
    """
    return sentence_stop_regex.sub('.', text)


# One or more consecutive whitespace characters.
whitespace_regex = re.compile(r'\s+')


def replace_whitespaces(text: str) -> str:
    """Collapse every run of whitespace in `text` into a single space."""
    collapsed = re.sub(whitespace_regex, ' ', text)
    return collapsed


def clean_ocr(ocr: str) -> str:
    """
    Keep only the OCR lines that are longer than 5 characters and consist of
    less than 50% `dirty_chars`; a line failing either test is dropped. The
    surviving lines are re-joined with newlines.
    """
    kept_lines = []
    for line in ocr.split('\n'):
        if len(line) <= 5:
            continue
        dirty_share = sum(char in dirty_chars for char in line) / len(line)
        if dirty_share < 0.5:
            kept_lines.append(line)
    return '\n'.join(kept_lines)


def clean_twitter_picture_links(text):
    """
    Replace every `pic.twitter.com/<id>` link in `text` with the word 'pic'.

    The dots are escaped so that only the literal host name matches; an
    unescaped `.` would accept any character in those positions.
    """
    return re.sub(r'pic\.twitter\.com/\S+', 'pic', text)


def clean_twitter_links(text):
    """
    Replace every shortened twitter link (`<scheme>//t.co/<id>`) in `text`
    with 't.co'.

    The dot is escaped so that only the literal host `t.co` matches; an
    unescaped `.` would also accept hosts such as `taco`.
    """
    return re.sub(r'\S+//t\.co/\S+', 't.co', text)


def remove_elongation(text):
    """
    Collapse elongated text: wherever a non-space sequence is immediately
    repeated so that it occurs at least three times in a row, keep a single
    occurrence. This covers single characters ('sooo' -> 'so') as well as
    longer repeated chunks ('hahaha' -> 'ha').
    """
    repeated_chunk = re.compile(r'(\S+)\1{2,}')
    return repeated_chunk.sub(r'\1', text)

def safe_literal_eval(value):
    """
    Parse `str(value)` as a Python literal and return the parsed object.

    If the text is not a valid literal the original `value` is returned
    unchanged. Besides `ValueError` and `SyntaxError`, `ast.literal_eval` can
    raise `TypeError` (e.g. an unhashable dict key), `MemoryError` and
    `RecursionError` on malformed input, so those are handled the same way.
    """
    try:
        return ast.literal_eval(str(value))
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        return value  # Or `None`, depending on how you want to handle it

In [8]:
"""
Library of custom types used for datasets. Some basic functions over these types are also included.
"""

Id2FactCheck = Dict[int, str]  # Fact-check id -> fact-check text
Id2Post = Dict[int, str]  # Post id -> post text
FactCheckPostMapping = List[Tuple[int, int]]

Language = str  # ISO 639-3 language code
# Confidence for language identification, as (language, share) pairs. The
# shares should sum to a value in [0, 1]. This is a list of pairs rather than
# a dict: the functions in this cell iterate it as `for language, percentage
# in ...` and build it with `list(d.items())`.
LanguageDistribution = List[Tuple[Language, float]]

OriginalText = str
EnglishTranslation = str
TranslatedText = Tuple[OriginalText, EnglishTranslation, LanguageDistribution]

Instance = Tuple[Optional[datetime.datetime], str]  # When and where was a fact-check or post published


def is_in_distribution(language: Language, distribution: LanguageDistribution, threshold: float = 0.2) -> bool:
    """
    Check whether `language` makes up at least `threshold` x 100% of `distribution`.

    Only the first entry for `language` is considered; a language that does not
    occur at all yields False.
    """
    for candidate, share in distribution:
        if candidate == language:
            return share >= threshold
    return False


def combine_distributions(texts: Iterable[TranslatedText]) -> LanguageDistribution:
    """
    Combine the `LanguageDistribution`s of several `TranslatedText`s into one,
    weighting each text by the length of its original text.

    Args:
        texts: Any iterable of `TranslatedText`; it is materialised first
            because it has to be traversed twice, so a generator is fine.

    Returns:
        A list of (language, weighted share) pairs. The list is empty when
        there are no texts or every original text is empty (no weights can be
        derived in that case).
    """
    texts = list(texts)
    total_length = sum(len(original_text) for original_text, _, _ in texts)
    if total_length == 0:
        return []

    distribution = defaultdict(float)
    for original_text, _, text_distribution in texts:
        weight = len(original_text) / total_length
        for language, percentage in text_distribution:
            distribution[language] += percentage * weight
    return list(distribution.items())

In [9]:
class Dataset:
    """Dataset

    Base class for datasets. Subclasses implement `load`, which is expected to
    populate `id_to_fact_check`, `id_to_post` and `fact_check_post_mapping`.
    The class also bundles the text cleaning options shared by all datasets.

    Attributes:
        clean_ocr: bool = True  Should cleaning of OCRs be performed?
        remove_emojis: bool = True  Should emojis be removed from the texts?
        remove_urls: bool = True  Should URLs be removed from the texts?
        replace_whitespaces: bool = True  Should runs of whitespace be replaced by a single space?
        clean_twitter: bool = True  Should twitter picture links and t.co links be shortened?
        remove_elongation: bool = False  Should immediately repeated non-space sequences
            (at least three in a row) be collapsed into a single occurrence?

        After `load` is called, following attributes are accessible:
                fact_check_post_mapping: list[tuple[int, int]]  List of (post id, fact-check id) pairs.
                id_to_fact_check: dict[int, str]  Factcheck id -> Factcheck text
                id_to_post: dict[int, str]  Post id -> Post text

    Methods:
        clean_text: Performs text cleaning based on initialization attributes.
        maybe_clean_ocr: Performs OCR-specific text cleaning, if `self.clean_ocr`.
        load: Abstract method. To be implemented by the subclasses.
    """

    # The default values here are based on our preliminary experiments. Might not be the best for all cases.
    def __init__(
        self,
        clean_ocr: bool = True,
        dataset: str = None,  # Here to read and discard the `dataset` field from the argparser
        remove_emojis: bool = True,
        remove_urls: bool = True,
        replace_whitespaces: bool = True,
        clean_twitter: bool = True,
        remove_elongation: bool = False
    ):
        # Switches consumed by `clean_text`
        self.remove_urls = remove_urls
        self.remove_emojis = remove_emojis
        self.replace_whitespaces = replace_whitespaces
        self.clean_twitter = clean_twitter
        self.remove_elongation = remove_elongation
        # Switch consumed by `maybe_clean_ocr`
        self.clean_ocr = clean_ocr

    def __len__(self):
        """Number of (post, fact-check) pairs."""
        pairs = self.fact_check_post_mapping
        return len(pairs)

    def __getitem__(self, idx):
        """Return `(fact_check_text, post_text)` for the pair at position `idx`."""
        post_id, fact_check_id = self.fact_check_post_mapping[idx]
        fact_check_text = self.id_to_fact_check[fact_check_id]
        post_text = self.id_to_post[post_id]
        return fact_check_text, post_text

    def clean_text(self, text):
        """Apply every enabled cleaning step in a fixed order, then strip the result."""
        if self.remove_urls:
            text = remove_urls(text)
        if self.remove_emojis:
            text = remove_emojis(text)
        if self.replace_whitespaces:
            text = replace_whitespaces(text)
        if self.clean_twitter:
            text = clean_twitter_links(clean_twitter_picture_links(text))
        if self.remove_elongation:
            text = remove_elongation(text)
        return text.strip()

    def maybe_clean_ocr(self, ocr):
        """Return the cleaned OCR text when `self.clean_ocr` is set, else `ocr` unchanged."""
        return clean_ocr(ocr) if self.clean_ocr else ocr

    def __getattr__(self, name):
        """
        Called only when normal attribute lookup fails. Gives a helpful hint
        for the attributes that exist only after `load`.
        """
        loaded_only = ('id_to_fact_check', 'id_to_post', 'fact_check_post_mapping')
        if name in loaded_only:
            message = f"You have to `load` the dataset first before using '{name}'"
        else:
            message = f"'{self.__class__.__name__}' object has no attribute '{name}'"
        raise AttributeError(message)

    def load(self):
        """Abstract: subclasses populate the dataset attributes here."""
        raise NotImplementedError
        


In [10]:

class OurDataset(Dataset):
    """Our dataset
    
    Loads the posts and fact-checks for one task of the test phase. The texts
    come from `test_posts_text.csv` and `test_fact_checks_text.csv` under
    `our_dataset_path`; the ids that belong to the task come from `tasks.json`.

    Initialization Attributes:
        language: Optional[Language]  Key of the monolingual task in `tasks.json`, or the literal
            string 'cross' to select the crosslingual task.
        split: Must be 'test' or 'train' (asserted). Stored, but not read by `load`.
        fold: int  Stored, but not read by `load`.
        fact_check_fields: Iterable[str]  Only validated (`claim` / `title`); not used otherwise.
        crosslingual, fact_check_language, post_language: Accepted but unused in this class.
        
        Also check `Dataset` attributes.
        
        After `load` is called, following attributes are accessible:
            id_to_fact_check: dict[int, str]  Factcheck id -> cleaned Factcheck text
            id_to_post: dict[int, str]  Post id -> cleaned Post text

        NOTE(review): `load` does not set `fact_check_post_mapping`, so `len()` and indexing
        raise the "load first" AttributeError from `Dataset` even after loading -- presumably
        intentional for the unlabeled test phase; confirm.
        NOTE(review): the default `split=None` fails the assertion in `__init__`, so callers
        must pass `split` explicitly.

    Methods:
        load: Loads the data from the csv files. Populates `id_to_fact_check` and `id_to_post`.
    """
        
    # Directory that holds the two CSV files read by `maybe_load_csvs`.
    our_dataset_path = join_path('/kaggle/input/semeval-data')
    # Class-level flag: set to True once the CSVs have been parsed.
    csvs_loaded = False

    
    def __init__(
        self,
        crosslingual: bool = False,
        fact_check_fields: Iterable[str] = ('claim', 'title'),
        fact_check_language: Optional[Language] = None,
        language: Optional[Language] = None,
        post_language: Optional[Language] = None,
        split: Optional[str] = None,
        # version: str = 'original',
        fold: int = 0,
        **kwargs
    ):
        # Remaining keyword arguments are the cleaning switches of `Dataset`.
        super().__init__(**kwargs)
        
        assert all(field in ('claim', 'title') for field in fact_check_fields)
        assert split in ('test', 'train')
        
        # self.crosslingual = crosslingual
        self.split = split
        self.fold = fold
        self.language = language
        
    @classmethod
    def maybe_load_csvs(cls):
        """
        Load the csvs and store them as class variables. When individual objects are initialized, they can reuse the same
        pre-loaded dataframes without costly text parsing.
        
        `OurDataset.csvs_loaded` is a flag indicating whether the csvs are already loaded.
        """
        
        if cls.csvs_loaded:
            return
        
        posts_path = join_path(cls.our_dataset_path, 'test_posts_text.csv')
        fact_checks_path = join_path(cls.our_dataset_path, 'test_fact_checks_text.csv')
        
        # Fail early if either file is missing.
        for path in [posts_path, fact_checks_path]:
            assert os.path.isfile(path)
        
        # Missing values become empty strings; the id column becomes the index.
        print('Loading fact-checks.')
        cls.df_fact_checks = pd.read_csv(fact_checks_path).fillna('').set_index('fact_check_id')
        print(f'{len(cls.df_fact_checks)} loaded.')

            
        print('Loading posts.')
        cls.df_posts = pd.read_csv(posts_path).fillna('').set_index('post_id')
        print(f'{len(cls.df_posts)} loaded.')

        cls.csvs_loaded = True
        

    def load(self):
        """
        Select the fact-checks and posts of the configured task and clean their texts.

        Reads `tasks.json`, keeps the rows whose ids are listed for the task chosen by
        `self.language`, runs `clean_text` on the `clean_text` column of both frames and
        stores the results in `id_to_fact_check` / `id_to_post`. Returns `self`.
        """
        
        self.maybe_load_csvs()
        
        # Work on copies so the cached class-level frames stay untouched.
        df_posts = self.df_posts.copy()
        df_fact_checks = self.df_fact_checks.copy()

        # NOTE(review): requires a 'post_lang' column in the fact-checks CSV -- confirm it exists.
        print(df_fact_checks['post_lang'].value_counts())
        
        if self.language != 'cross':

            # Monolingual task: the row whose index equals the language holds a dict
            # with the 'fact_checks' and 'posts_test' id lists.
            df = pd.read_json('/kaggle/input/semeval-data/SemEval_Task7_Test_Phase/tasks.json').reset_index()
            lang_fact_checks = df[df['index']==self.language]['monolingual'].values[0]['fact_checks']
            print('task lang fact checks', len(lang_fact_checks))
            lang_post_dev = df[df['index']==self.language]['monolingual'].values[0]['posts_test']
            print('task lang post dev', len(lang_post_dev))

            df_fact_checks = df_fact_checks[df_fact_checks.index.isin(lang_fact_checks)]
            df_posts = df_posts[df_posts.index.isin(lang_post_dev)]
    
            print(f'Filtering by split: {len(df_posts)} posts remaining and sampled {len(df_fact_checks)} fact checks')
            
        else:
            
            # Crosslingual task: the id lists sit directly in the 'crosslingual' column,
            # in the rows indexed 'fact_checks' and 'posts_test'.
            df = pd.read_json('/kaggle/input/semeval-data/SemEval_Task7_Test_Phase/tasks.json').reset_index()
            lang_fact_checks = df[df['index']=='fact_checks']['crosslingual'].dropna().values[0]
            print('task lang fact checks', len(lang_fact_checks))
            lang_post_dev = df[df['index']=='posts_test']['crosslingual'].dropna().values[0]
            print('task lang post dev', len(lang_post_dev))

            df_fact_checks = df_fact_checks[df_fact_checks.index.isin(lang_fact_checks)]
            df_posts = df_posts[df_posts.index.isin(lang_post_dev)]
    
            print(f'Filtering by split: {len(df_posts)} posts remaining and sampled {len(df_fact_checks)} fact checks')

        # NOTE(review): no further filtering happens below; these messages only repeat the
        # counts from above (and the first one reports fact-checks as "posts").
        print(f'Filtering fact-checks by language: {len(df_fact_checks)} posts remaining.')
        print(f'Filtering posts by language: {len(df_posts)} posts remaining.')
        print(f'Filtering posts.')
        print(f'Posts remaining: {len(df_posts)}')
            

        self.id_to_post = dict()
        for post_id, post_text in zip(df_posts.index, df_posts['clean_text']):
            self.id_to_post[post_id] = self.clean_text(post_text)

        self.id_to_fact_check = dict()
        for fact_check_id, fact_text in zip(df_fact_checks.index, df_fact_checks['clean_text']):
            self.id_to_fact_check[fact_check_id] = self.clean_text(fact_text)
            
        return self
    

In [11]:
# Sanity check: count the fact-check and post ids listed for the crosslingual
# task in tasks.json (same lookup as the 'cross' branch of `OurDataset.load`).
# NOTE(review): leaves `df`, `lang_fact_checks` and `lang_post_dev` behind as notebook globals.
df = pd.read_json('/kaggle/input/semeval-data/SemEval_Task7_Test_Phase/tasks.json').reset_index()
lang_fact_checks = df[df['index']=='fact_checks']['crosslingual'].dropna().values[0]
print('task lang fact checks', len(lang_fact_checks))
lang_post_dev = df[df['index']=='posts_test']['crosslingual'].dropna().values[0]
print('task lang post dev', len(lang_post_dev))

task lang fact checks 272447
task lang post dev 4000


In [12]:
def process_result_generator(gen: Generator, default_rank: int = None, csv_path: str = None):
    """
    Exhaust `gen` and, when `csv_path` is given, dump its results into a csv file.

    Every item of `gen` must be a `(predicted_ids, probab_ids, post_id)` triple.
    For the csv the first 50 probabilities and the first 50 predicted ids of
    each post are kept, in the columns `post_id`, `predicted_fact_check_probs`
    and `predicted_fact_check_ids`. The generator is consumed even when no csv
    is written. `default_rank` is accepted but not used here.
    """
    records = []
    for fact_check_ids, fact_check_probs, post_id in gen:
        if not csv_path:
            continue
        records.append((post_id, fact_check_probs[:50], fact_check_ids[:50]))

    if csv_path:
        frame = pd.DataFrame(
            records,
            columns=['post_id', 'predicted_fact_check_probs', 'predicted_fact_check_ids'],
        )
        frame.to_csv(csv_path, index=False)

    print('done')
def result_generator(func):
    """
    Decorator for result generators.

    The wrapped generator is called as `func(dataset, *args, **kwargs)` and
    must yield `(predicted_fact_check_ids, probab_ids, post_id)` triples. The
    wrapper re-yields each triple unchanged; no ground-truth fact-check ids are
    added at this point. `functools.wraps` keeps the decorated function's name
    and docstring.
    """
    from functools import wraps

    @wraps(func)
    def wrapper(dataset, *args, **kwargs):
        for predicted_fact_check_ids, probab_ids, post_id in func(dataset, *args, **kwargs):
            yield predicted_fact_check_ids, probab_ids, post_id
    return wrapper

In [13]:
class Vectorizer:
    """
    An abstract class for vectorizers. Vectorizers calculate vectors for texts and cache them, so that the vector for
    the same text is never calculated more than once.

    Subclasses implement `_calculate_vectors`.

    The main call is `vectorize`. It calculates the missing vectors and stores them in the `dict` attribute
    (text -> vector). The class also supports `save` and `load`: `dir_path` is a folder holding `vocab.json` with the
    list of texts and `vectors.pt` with a torch tensor of the vectors, in the same order. With a falsy `dir_path` the
    vectorizer works purely in memory.
    """

    def __init__(self, dir_path: str):
        """
        Args:
            dir_path: Folder of the on-disk cache; a falsy value disables persistence. An existing cache is loaded
                right away.
        """
        self.dict = {}

        self.dir_path = dir_path
        if dir_path:
            self.vectors_path = os.path.join(dir_path, 'vectors.pt')
            self.vocab_path = os.path.join(dir_path, 'vocab.json')

            try:
                self.load()
                print(f'Vector database with {len(self.dict)} records loaded')
            except FileNotFoundError:
                print(f'No previous database found')

    def vectorize(self, texts: List[str], save_if_missing: bool = False, normalize: bool = False) -> torch.tensor:
        """
        The main API point for the users. Try to find the vectors in the existing database. For the missing texts,
        the vectors are calculated and stored in `self.dict`.

        Attributes:
            texts: Texts to vectorize; any iterable is accepted (it is materialised once).
            save_if_missing: bool  Should the vectors in `dict` be saved after new vectors are calculated? This makes
                sense for models that will be used more than once. Skipped when no `dir_path` is configured.
            normalize: bool  Should the vectors be normalized. Useful for cosine similarity calculations.

        Returns:
            One row per input text, in input order.
        """
        # `texts` is traversed more than once; also accept dict views and generators.
        texts = list(texts)

        # Unique missing texts in first-seen order, so the order in which vectors are
        # calculated (and later saved) is reproducible between runs.
        missing_texts = list(dict.fromkeys(text for text in texts if text not in self.dict))

        if missing_texts:

            print(f'Calculating {len(missing_texts)} vectors.')
            missing_vectors = self._calculate_vectors(missing_texts)
            for text, vector in zip(missing_texts, missing_vectors):
                self.dict[text] = vector

            if save_if_missing:
                if self.dir_path:
                    self.save()
                else:
                    # In-memory vectorizer: there is nowhere to persist to.
                    print('No `dir_path` configured, vectors are not saved.')

        vectors = torch.vstack([
            self.dict[text]
            for text in texts
        ])

        if normalize:
            vectors = torch.nn.functional.normalize(vectors, p=2, dim=1)

        return vectors

    def _calculate_vectors(self, txts: List[str]) -> torch.tensor:
        """
        Abstract method to be implemented by subclasses.
        """
        raise NotImplementedError

    def load(self):
        """
        Load vocab and vectors from appropriate files.

        Raises:
            FileNotFoundError: If the cache files do not exist.
        """
        with open(self.vocab_path, 'r', encoding='utf8') as f:
            vocab = json.load(f)

        vectors = torch.load(self.vectors_path)

        assert len(vocab) == len(vectors)

        self.dict = {
            text: vector
            for text, vector in zip(vocab, vectors)
        }

    def save(self):
        """
        Save vocab and vectors to appropriate files.

        Raises:
            ValueError: If the vectorizer was created without a `dir_path`.
        """
        if not self.dir_path:
            raise ValueError('Cannot save vectors: this vectorizer was created without a `dir_path`.')

        os.makedirs(self.dir_path, exist_ok=True)

        vocab = list(self.dict.keys())
        with open(self.vocab_path, 'w', encoding='utf8') as f:
            json.dump(vocab, f)

        vectors = torch.vstack(list(self.dict.values()))
        torch.save(vectors, self.vectors_path)
            

In [14]:
def slice_text(text, window_type, window_size, window_stride=None) -> List[str]:
    """
    Split a `text` into parts using a sliding window. The window slides either across characters or sentences, based
    on the value of `window_type`.

    Attributes:
        text: str  Text that is to be split into windows.
        window_type: str  Either `sentence` or `character`. The basic unit of the windows.
        window_size: int  How many units are in a window.
        window_stride: int  How many units are skipped each time the window moves. Defaults to `window_size`
            (non-overlapping windows).

    Returns:
        The list of windows, in text order.

    Raises:
        ValueError: If `window_type` is neither `sentence` nor `character` (previously this silently returned `None`).
    """
    if window_type not in ('sentence', 'character'):
        raise ValueError(
            f"Unknown window_type {window_type!r}; expected 'sentence' or 'character'."
        )

    text = replace_whitespaces(text)

    if window_stride is None:
        window_stride = window_size

    if window_size < window_stride:
        print(f'Window size ({window_size}) is smaller than stride length ({window_stride}). This will result in missing chunks of text.')

    if window_type == 'sentence':
        # Normalise the sentence terminators first so that the tokenizer can see them.
        text = replace_stops(text)
        units = sent_tokenize(text)
        return [
            ' '.join(units[i:i + window_size])
            for i in range(0, len(units), window_stride)
        ]

    # window_type == 'character'
    return [
        text[i:i + window_size]
        for i in range(0, len(text), window_stride)
    ]


def gen_sliding_window_delimiters(post_lengths: List[int], max_size: int) -> Generator[Tuple[int, int], None, None]:
    """
    Calculate where to split the sequence of `post_lengths` so that the individual batches do not exceed `max_size`.

    Yields `(start, end)` offsets into the concatenation of all posts. Each range covers whole posts only. A single
    post longer than `max_size` gets a range of its own (which then exceeds `max_size`). Empty ranges are never
    yielded; previously a first post larger than `max_size` produced a spurious `(0, 0)`.
    """
    range_length = start = cur_sum = 0

    for post_length in post_lengths:
        if (range_length + post_length) > max_size:  # exceeds memory
            if range_length > 0:  # nothing collected yet -> no empty batch
                yield (start, start + range_length)
            start = cur_sum
            range_length = post_length
        else:  # memory still avail in current split
            range_length += post_length
        cur_sum += post_length

    if range_length > 0:
        yield (start, start + range_length)


@result_generator
def embedding_results(
    dataset: Dataset,
    vectorizer_fact_check: Vectorizer,
    vectorizer_post: Vectorizer,
    sliding_window: bool = False,
    sliding_window_pooling: str = 'max',
    sliding_window_size: Optional[int] = None,
    sliding_window_stride: Optional[int] = None,
    sliding_window_type: Optional[str] = None,
    post_split_size: int = 256,
    dtype: torch.dtype = torch.float16,
    device: str = 'cpu',
    save_if_missing: bool = False

):
    """
    Generate results using cosine similarity based on embeddings generated via vectorizers.

    Attributes:
        dataset: Dataset
        vectorizer_fact_check: Vectorizer  Vectorizer used to process fact-checks
        vectorizer_post: Vectorizer  Vectorizer used to process posts
        sliding_window: bool  Should sliding window be used or should texts be processed without slicing.
        sliding_window_pooling: str  One of 'sum', 'mul', 'mean', 'min', 'max' as defined here: https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html
        sliding_window_size, sliding_window_stride, sliding_window_type:  See `slice_text`
        post_split_size: int  Batch size for post embeddings for sim calculation
        dtype: torch.dtype  Data type in which to calculate sim
        device: str  Device on which to calculate sim
        save_if_missing: bool  Should the vectors in `dict` be saved after new vectors are calculated? This makes sense for models that will
                be used more than once.

    Yields:
        Without sliding window: `(fact_check_ids, similarities, post_id)` where both arrays are sorted by
        descending similarity. With sliding window: `(fact_check_ids, post_id)`.
        NOTE(review): the two branches yield tuples of different length; confirm that `result_generator`
        and its consumers handle both shapes.
    """

    print('Calculating embeddings for fact checks')
    fact_check_embeddings = vectorizer_fact_check.vectorize(
        dataset.id_to_fact_check.values(),
        save_if_missing=save_if_missing,
        normalize=True
    )
    fact_check_embeddings = fact_check_embeddings.transpose(0, 1)  # Rotate for matmul

    fact_check_embeddings = fact_check_embeddings.to(device=device, dtype=dtype)

    # Column index of the similarity matrix -> fact-check id. This does not change between batches,
    # so it is built once instead of once per batch.
    fact_check_ids = {i: fc_id for i, fc_id in enumerate(dataset.id_to_fact_check.keys())}
    to_fact_check_ids = np.vectorize(fact_check_ids.__getitem__)

    # The calculation is split into batches because of memory limitations: the full similarity matrix
    # alone would require 200k x 25k x 4 = ~20GB RAM.
    post_ids = iter(dataset.id_to_post.keys())

    if sliding_window:

        print('Splitting posts into windows.')
        windows = [
            slice_text(post, sliding_window_type, sliding_window_size, sliding_window_stride)
            for post in dataset.id_to_post.values()
        ]

        print('Calculating embeddings for the windows')
        post_embeddings = vectorizer_post.vectorize(
            list(itertools.chain.from_iterable(windows)),
            save_if_missing=save_if_missing,
            normalize=True
        )

        # We need to split the matrix matmul so that all the windows from each post belong to the same batch.
        # NOTE(review): a post that produces zero windows contributes no row to `segment_array`, which
        # presumably shifts the pairing with `post_ids` below -- confirm such posts cannot occur.
        post_lengths = [len(post) for post in windows]
        segment_array = torch.tensor([
            i
            for i, num_windows in enumerate(post_lengths)
            for _ in range(num_windows)
        ])
        delimiters = list(gen_sliding_window_delimiters(post_lengths, post_split_size))

        print('Calculating similarity for data splits')

        for start_id, end_id in delimiters:

            if start_id == end_id:
                # Nothing to score in an empty range; `segments[0]` below would raise on it.
                continue

            sims = torch.mm(
                post_embeddings[start_id:end_id].to(device=device, dtype=dtype),
                fact_check_embeddings
            )

            # Re-base the post indices so that the first post of the batch is segment 0. This is done
            # out of place (the slice is a view of `segment_array`) and the index is moved next to
            # `sims`, because the scatter index has to live on the same device as its source.
            segments = segment_array[start_id:end_id]
            segments = (segments - int(segments[0])).to(sims.device)

            # Pool the window-level similarities into one row per post.
            sims = scatter(
                src=sims,
                index=segments,
                dim=0,
                reduce=sliding_window_pooling,
            )

            sorted_ids = torch.argsort(sims, descending=True, dim=1)

            for row in sorted_ids:
                row = row.cpu().numpy()
                row = to_fact_check_ids(row)
                yield row, next(post_ids)


    else:

        print('Calculating embeddings for posts')
        post_embeddings = vectorizer_post.vectorize(
            dataset.id_to_post.values(),
            save_if_missing=save_if_missing,
            normalize=True
        )

        print('Calculating similarity for data splits')
        for start_id in range(0, len(dataset.id_to_post), post_split_size):
            end_id = start_id + post_split_size

            sims = torch.mm(
                post_embeddings[start_id:end_id].to(device=device, dtype=dtype),
                fact_check_embeddings
            )

            # TODO: sorting does not take duplicate similarities into account, the results might not be deterministic
            sorted_sims, sorted_ids = torch.sort(sims, descending=True, dim=1)

            for ids_row, sims_row in zip(sorted_ids, sorted_sims):
                ids_row = ids_row.cpu().numpy()
                sims_row = sims_row.cpu().numpy()

                fact_ids = to_fact_check_ids(ids_row)
                yield fact_ids, sims_row, next(post_ids)

In [15]:
class PytorchVectorizer(Vectorizer):
    """
    Vectorizer for `Pytorch` models.
    """

    def __init__(
        self,
        dir_path: str,
        model_handle: str = None,
        model: torch.nn.Module = None,
        tokenizer = None,
        batch_size: int = 32,
        dtype: torch.dtype = torch.float16,
        port_embeddings_to_cpu: bool = True
    ):
        """
        Attributes:
            dir_path: str  Path to cached vectors and vocab files.
            model_handle: str  Path to a model object stored with `torch.save`; it is loaded with `torch.load`
                    and put into eval mode. Takes precedence over `model`.
            model: torch.nn.Module  A loaded model -- this option can be used during fine-tuning. Its
                    train/eval mode is left untouched.
            tokenizer: AutoTokenizer  A tokenizer for the model.
            batch_size: int  Batch size for inference
            dtype: torch.dtype  Inference dtype
            port_embeddings_to_cpu: bool  Whether to move the embeddings to CPU after inference.
        """

        super().__init__(dir_path)

        if model_handle:
            # NOTE(review): `torch.load` unpickles arbitrary objects -- only use it on trusted files.
            # Recent PyTorch releases may also refuse to load a whole pickled module unless
            # `weights_only=False` is passed; confirm against the installed version.
            self.model = torch.load(model_handle)
            self.model.eval()
        else:
            self.model = model

        # Compare against None explicitly: truthiness of a module is not a reliable "was it provided"
        # check (container modules define `__len__`, so an empty one is falsy).
        assert self.model is not None, 'Either `model_handle` or `model` has to be provided.'

        # Bare device type ('cpu', 'cuda', ...) without an index, kept for backward compatibility.
        self.device = next(self.model.parameters()).device.type
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.dtype = dtype
        self.port_embeddings_to_cpu = port_embeddings_to_cpu


    def _calculate_vectors(self, texts: List[str]) -> torch.Tensor:
        """
        Embed `texts` in batches of `self.batch_size` and stack the results row-wise.
        """

        # Resolve the device from the model at call time, including its index. The bare device type
        # stored in `self.device` would send the inputs to the default GPU even when the model lives
        # on another one (e.g. `cuda:1`).
        model_device = next(self.model.parameters()).device

        @torch.autocast(device_type=model_device.type, dtype=self.dtype)
        @torch.no_grad()
        def embedding_pipeline(text: List[str], tokenizer, model, device, max_length = 512):
            tokenized = tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors='pt').to(device)
            embeddings = model(**tokenized)
            return embeddings.cpu() if self.port_embeddings_to_cpu else embeddings

        return torch.vstack(
                [
                    embedding_pipeline(
                        texts[i:i+self.batch_size],
                        self.tokenizer,
                        self.model,
                        device=str(model_device),
                    )
                    for i in range(0, len(texts), self.batch_size)
                ]
            )
            

In [16]:
"""
https://github.com/mosaicml/composer/blob/dev/composer/algorithms/ema/ema.py
Exponential Moving Average (EMA) is a model averaging technique that maintains an 
exponentially weighted moving average of the model parameters during training. 
The averaged parameters are used for model evaluation. EMA typically results 
in less noisy validation metrics over the course of training, and sometimes 
increased generalization.
"""

def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99):
    """
    Update `ema_model` in place with an exponential moving average of `model`.

    Parameters and buffers of both models are paired by iteration order, so the two models must have
    the same structure.

    Attributes:
        model: torch.nn.Module  The model whose current values are blended in.
        ema_model: torch.nn.Module  The averaged model; modified in place.
        smoothing: float  Weight of the previous average. Each floating point tensor becomes
                `smoothing * ema + (1 - smoothing) * model`.
    """
    with torch.no_grad():
        model_params = itertools.chain(model.parameters(), model.buffers())
        ema_model_params = itertools.chain(ema_model.parameters(), ema_model.buffers())

        for ema_param, model_param in zip(ema_model_params, model_params):
            model_param = model_param.detach()
            if ema_param.is_floating_point() or ema_param.is_complex():
                ema_param.copy_(ema_param * smoothing + (1. - smoothing) * model_param)
            else:
                # Integer / bool buffers (step counters, index tensors, ...) cannot be averaged: the
                # blended value is a float that gets truncated when copied back, which silently
                # corrupts them. Mirror the current value instead.
                ema_param.copy_(model_param)


In [17]:
"""
https://github.com/mosaicml/composer/blob/dev/composer/algorithms/sam/sam.py
Sharpness-Aware Minimization (SAM) is an optimization algorithm that minimizes 
both the loss and the sharpness of the loss. It finds parameters that lie in 
a neighborhood of low loss. The authors find that this improves model generalization
"""

class SAM(torch.optim.Optimizer):
    """Wraps an optimizer with sharpness-aware minimization (`Foret et al, 2020 <https://arxiv.org/abs/2010.01412>`_).
    Implementation based on https://github.com/davda54/sam

    One SAM update consists of two calls: ``step('first')`` moves the weights to ``w + e(w)`` along the
    gradient, then (after gradients were recomputed at the perturbed weights) ``step('second')`` restores
    ``w`` and lets the base optimizer apply the update. ``step('skip')`` is a plain base optimizer step.

    Args:
        base_optimizer (torch.optim.Optimizer) The optimizer to apply SAM to.
        rho (float, optional): The SAM neighborhood size. Must be non-negative. Default: ``0.05``.
        epsilon (float, optional): A small value added to the gradient norm for numerical stability. Default: ``1.0e-12``.
        **kwargs: Additional entries stored in the optimizer defaults next to ``rho`` and ``epsilon``.
    """

    def __init__(
        self,
        base_optimizer: torch.optim.Optimizer,
        rho: float = 0.05,
        epsilon: float = 1.0e-12,
        **kwargs
    ):
        if rho < 0:
            raise ValueError(f'Invalid rho, should be non-negative: {rho}')
        self.base_optimizer = base_optimizer
        defaults = {'rho': rho, 'epsilon': epsilon, **kwargs}
        # The wrapper is built on the base optimizer's own param groups.
        super(SAM, self).__init__(self.base_optimizer.param_groups, defaults)

    @torch.no_grad()
    def first_step(self):
        """Perturb every parameter that has a gradient by ``rho * grad / (|grad| + epsilon)``."""
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + group['epsilon'])
            for p in group['params']:
                if p.grad is None:
                    continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]['e_w'] = e_w

    @torch.no_grad()
    def second_step(self):
        """Undo the perturbation of `first_step` and run the base optimizer."""
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Consume the stored perturbation so that it can be undone only once. A leftover
                # value would otherwise be subtracted again by a later 'second' step that was not
                # preceded by a 'first' step.
                e_w = self.state[p].pop('e_w', None)
                if e_w is None:
                    continue
                p.sub_(e_w)  # get back to "w" from "w + e(w)"
        self.base_optimizer.step()  # do the actual "sharpness-aware" update

    @torch.no_grad()
    def step(self, step_type: str):
        """ Adjusted to PyTorch mixed precision training framework

        Args:
            step_type (str): ``'first'``, ``'second'`` or ``'skip'`` (plain base optimizer step).

        Raises:
            ValueError: For any other ``step_type``; this used to be a silent no-op.
        """
        if step_type == 'first':
            self.first_step()
        elif step_type == 'second':
            self.second_step()
        elif step_type == 'skip':
            self.base_optimizer.step()
        else:
            raise ValueError(f"Unknown step_type {step_type!r}, expected 'first', 'second' or 'skip'.")


    def _grad_norm(self):
        """Global L2 norm over the gradients of all parameters that currently have one."""
        grads = [p.grad for group in self.param_groups for p in group['params'] if p.grad is not None]
        if not grads:
            # `torch.stack` rejects an empty list; without gradients there is nothing to perturb.
            return torch.zeros(())
        # Collect the per-tensor norms on one device so that gradients spread over several devices
        # can be stacked.
        shared_device = grads[0].device
        norm = torch.norm(torch.stack(
            [grad.norm(p=2).to(shared_device) for grad in grads]),
                          p='fro')
        return norm

In [18]:
class MNRloss(nn.Module):
    """
    Cross-entropy loss over an in-batch similarity matrix.

    Every row of the first embedding matrix is scored against all rows of the second one via a dot
    product; the row with the same index is treated as the correct class.
    """

    def __init__(self, label_smoothing=0):
        super().__init__()
        self.loss_f = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    def forward(self, sentence_embedding_A: torch.Tensor, sentence_embedding_B: torch.Tensor):
        # Entry (i, j) is the dot product of row i of A with row j of B.
        similarity = torch.mm(sentence_embedding_A, torch.transpose(sentence_embedding_B, 0, 1))

        # The matching pair of row i sits on the diagonal, so the target class of row i is i.
        num_rows = similarity.size(0)
        targets = torch.arange(num_rows, dtype=torch.long, device=similarity.device)

        return self.loss_f(similarity, targets)

In [19]:
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

class Model(nn.Module):
    """
    Encoder loaded via `AutoModel` with a configurable pooling of its outputs.

    Pooling modes:
        'default'    Return the second element of the encoder output.
        'mean'       Attention-mask-weighted mean of the token embeddings.
        'attention'  Bidirectional LSTM over the token embeddings followed by a learned attention
                     pooling; the output has `2 * lstm_hidden_size` features.
    """

    _POOLING_MODES = ('default', 'mean', 'attention')

    def __init__(self, model_name, pooling='default', hidden_size=768, lstm_hidden_size=128):
        """
        Attributes:
            model_name: str  Handle or path passed to `AutoModel.from_pretrained`.
            pooling: str  One of 'default', 'mean', 'attention'.
            hidden_size: int  Size of the encoder token embeddings. Only used by 'attention' pooling,
                    where it has to match the encoder.
            lstm_hidden_size: int  Hidden size of each LSTM direction. Only used by 'attention' pooling.

        Raises:
            ValueError  If `pooling` is not a supported mode.
        """
        super(Model, self).__init__()

        # Fail before the (potentially large) encoder is loaded. An unknown mode used to make
        # `forward` silently return None.
        if pooling not in self._POOLING_MODES:
            raise ValueError(f'Unknown pooling {pooling!r}, expected one of {self._POOLING_MODES}.')
        self.pooling = pooling

        # NOTE(review): `trust_remote_code=True` executes code shipped with the model repository;
        # only use it with trusted models.
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

        # The pooling head is only created when it is needed, so that the parameter set (and therefore
        # the state dict) of the other pooling modes stays unchanged. `attention_pooling` relies on
        # these two layers and used to fail with an AttributeError because they were never created.
        if pooling == 'attention':
            self.lstm = nn.LSTM(input_size=hidden_size, hidden_size=lstm_hidden_size, batch_first=True, bidirectional=True)
            self.attention = nn.Linear(lstm_hidden_size * 2, 1)  # Bidirectional LSTM

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        # `token_type_ids` is accepted so that tokenizer output can be unpacked into the call, but it is
        # not forwarded to the encoder.
        model_output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        hidden_states = model_output[0]  # Shape: (batch_size, seq_len, hidden_size)

        if self.pooling == 'default':
            return model_output[1]
        elif self.pooling == 'mean':
            return self.mean_pooling(hidden_states, attention_mask)
        elif self.pooling == 'attention':
            return self.attention_pooling(hidden_states, attention_mask)

        # Only reachable when `self.pooling` was changed after construction.
        raise ValueError(f'Unknown pooling {self.pooling!r}, expected one of {self._POOLING_MODES}.')

    @staticmethod
    def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
        """Average the token embeddings over the sequence dimension, ignoring masked positions."""
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        # The clamp avoids a division by zero for rows without any unmasked token.
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    def attention_pooling(self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor):
        """Pool the LSTM outputs with softmax attention weights computed over unmasked positions."""
        # Pass through LSTM
        lstm_output, _ = self.lstm(token_embeddings)

        # Compute attention scores; masked positions get -inf so that softmax assigns them zero weight
        attention_scores = self.attention(lstm_output).squeeze(-1)
        attention_scores = attention_scores.masked_fill(attention_mask == 0, float('-inf'))
        attention_weights = torch.softmax(attention_scores, dim=-1)

        # Compute weighted sum of the LSTM outputs
        attention_weights = attention_weights.unsqueeze(-1)
        weighted_sum = torch.sum(lstm_output * attention_weights, dim=1)

        return weighted_sum

def get_model_tokenizer(model_name: str, **kwargs):
    """
    Build the `Model` wrapper and the matching slow tokenizer for `model_name`.

    Extra keyword arguments are forwarded to `Model`. Returns a `(model, tokenizer)` pair.
    """
    wrapped_model = Model(model_name, **kwargs)
    slow_tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    return wrapped_model, slow_tokenizer

In [20]:
# Type alias: a tuple holding a dict that maps string keys to tensors.
# NOTE(review): `Tuple[X]` denotes a 1-tuple; if a batch can hold several dicts this should presumably
# be `Tuple[X, ...]` -- confirm against the code that uses the alias.
BATCH = Tuple[Dict[str, torch.Tensor]]


def safe_mean(x, round_to=4):
    """
    Arithmetic mean of `x` rounded to `round_to` decimal places.

    Returns None instead of raising when the mean cannot be computed (`StatisticsError`, e.g. for
    empty data).
    """
    try:
        average = mean(x)
    except StatisticsError:
        return None
    return round(average, round_to)

def evaluate_datasets(datasets, model, tokenizer, run_path, save_vectors, batch_size):
    """
    Run retrieval for every dataset and write the results to `<run_path>/<dataset_name>.csv`.

    Attributes:
        datasets: dict  Maps a dataset name (used as the CSV file name) to a dataset.
        model: torch.nn.Module  Embedding model; it is switched to eval mode and left in it.
        tokenizer:  Tokenizer matching `model`.
        run_path: str  Directory for the CSV files; also the vectorizer directory when `save_vectors` is set.
        save_vectors: bool  Whether the vectorizer is given `run_path` as its directory (None otherwise).
        batch_size: int  Inference batch size of the vectorizer.

    Returns:
        True once all datasets were processed.
    """
    model.eval()  # Set model to evaluation mode

    vectorizer_path = run_path if save_vectors else None
    vct = PytorchVectorizer(dir_path=vectorizer_path, model=model, tokenizer=tokenizer,
                            batch_size=batch_size, port_embeddings_to_cpu=True)

    # Similarities are computed on the GPU when one is available. The device used to be hard-coded to
    # 'cuda', which made evaluation impossible on CPU-only hosts. On the CPU fallback float32 is
    # requested because half-precision matmul is not available on CPU in every PyTorch version.
    if torch.cuda.is_available():
        similarity_options = {'device': 'cuda'}
    else:
        similarity_options = {'device': 'cpu', 'dtype': torch.float32}

    with torch.no_grad():  # Disable gradient computation
        for dataset_name, dataset in datasets.items():
            result_gen = embedding_results(dataset, vct, vct, post_split_size=512, **similarity_options)
            process_result_generator(result_gen, csv_path=os.path.join(run_path, dataset_name + '.csv'))

    return True
    

def predict_dataset(cfg, dev_datasets):
    """
    Load the model described by `cfg`, restore its weights and write predictions for `dev_datasets`.

    Output layout: `<cfg.train.output_path>/<cfg.timestamp>/config.yaml` holds the configuration and
    `<cfg.train.output_path>/<cfg.timestamp>/<cfg.model.name>/` receives one CSV per dataset.

    Attributes:
        cfg:  Configuration; reads `train.output_path`, `train.device`, `train.batch_size`, `timestamp`,
                `model.name`, `model.pooling` and `model.path`.
        dev_datasets: dict  Maps dataset names to datasets, see `evaluate_datasets`.

    Raises:
        ValueError  If `cfg.train.output_path` is None.
    """

    # The output directory is required for the result files. Previously every failure in this setup was
    # swallowed by a bare `except`, and the function crashed with a NameError only after the model
    # had already been loaded.
    if cfg.train.output_path is None:
        raise ValueError('cfg.train.output_path must be set: predictions are written below it.')

    run_path = os.path.join(cfg.train.output_path, cfg.timestamp)
    os.makedirs(run_path, exist_ok=True)

    # Saving the config stays best-effort, but a failure is reported instead of being hidden.
    try:
        with open(os.path.join(run_path, 'config.yaml'), 'w') as f:
            yaml.dump(cfg.to_dict(), f, default_flow_style=False)
    except Exception as exc:
        print(f'Could not save config: {exc}')

    model_path = os.path.join(run_path, cfg.model.name)
    os.makedirs(model_path, exist_ok=True)

    print('Get model')
    model, tokenizer = get_model_tokenizer(cfg.model.name, pooling=cfg.model.pooling)

    print('Loading Model Weights')
    # Load onto the CPU first so that a checkpoint saved on a GPU can be restored on any host; the model
    # is moved to the target device right below.
    model.load_state_dict(torch.load(cfg.model.path, weights_only=True, map_location='cpu'))

    model = model.to(cfg.train.device)

    print('eval done' , evaluate_datasets({**dev_datasets}, model, tokenizer,
                                model_path, save_vectors=False, batch_size=cfg.train.batch_size))


In [21]:
'fra', 'spa', 'eng', 'por', 'tha', 'deu', 'msa', 'ara'

('fra', 'spa', 'eng', 'por', 'tha', 'deu', 'msa', 'ara')

In [22]:
# train_dataset = OurDataset(split='train', fold=2).load()

# Configure the PyTorch CUDA caching allocator through its environment variable.
# NOTE(review): the variable is presumably only picked up if it is set before the first CUDA
# allocation of the process -- confirm this cell runs before any tensor is moved to the GPU.
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# train(cfg, train_dataset, dev_dataset)

In [None]:
# Prediction run with mean pooling; the checkpoint file name suggests the weights of fold 0.
cfg = Config()
cfg.model.name = 'intfloat/multilingual-e5-large-instruct'
cfg.model.path = '/kaggle/input/model-weights/multilingual-e5-large-instruct_f0_mean.pt'
cfg.model.pooling = 'mean'
cfg.train.batch_size = 512
# NOTE(review): `predict_dataset` as defined in this notebook does not read the two sequence length
# settings below -- confirm they are consumed elsewhere.
cfg.train.p_max_seq_length: int = 768
cfg.train.fc_max_seq_length: int = 768
# The dict key becomes the name of the CSV file written by `evaluate_datasets`; note that the 'test'
# split is loaded under a 'dev_*' key.
dev_dataset = {'dev_cross': OurDataset(split='test', language = 'cross').load()}
predict_dataset(cfg, dev_dataset)

Loading fact-checks.
272447 loaded.
Loading posts.
8276 loaded.
post_lang
eng    145287
por     32598
spa     25440
ara     21153
tur     12536
        11567
pol      8796
deu      7485
fra      6316
msa       686
tha       583
Name: count, dtype: int64
task lang fact checks 272447
task lang post dev 4000
Filtering by split: 4000 posts remaining and sampled 272447 fact checks
Filtering fact-checks by language: 272447 posts remaining.
Filtering posts by language: 4000 posts remaining.
Filtering posts.
Posts remaining: 4000
Get model


config.json:   0%|          | 0.00/690 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.12G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/964 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.1M [00:00<?, ?B/s]

Loading Model Weights
Calculating embeddings for fact checks
Calculating 272256 vectors.


In [None]:
# Prediction run with mean pooling; the checkpoint file name suggests the weights of fold 1.
cfg = Config()
cfg.model.name = 'intfloat/multilingual-e5-large-instruct'
cfg.model.path = '/kaggle/input/model-weights/multilingual-e5-large-instruct_f1_mean.pt'
cfg.model.pooling = 'mean'
cfg.train.batch_size = 512
# NOTE(review): `predict_dataset` as defined in this notebook does not read the two sequence length
# settings below -- confirm they are consumed elsewhere.
cfg.train.p_max_seq_length: int = 768
cfg.train.fc_max_seq_length: int = 768
# The dict key becomes the name of the CSV file written by `evaluate_datasets`; note that the 'test'
# split is loaded under a 'dev_*' key.
dev_dataset = {'dev_cross': OurDataset(split='test', language = 'cross').load()}
predict_dataset(cfg, dev_dataset)