In [1]:
import os
import pickle
import math
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

In [2]:
config = {
    'dataset': 'uklex18', # choices=['arxiv', 'drug', 'huffpost', 'mimic', 'fmow', 'yearbook']
    'method': 'erm', # choices=['er', 'coral', 'ensemble', 'ewc', 'ft', 'groupdro', 'irm', 'si', 'erm', 'simclr', 'swav', 'swa']
    'device': 0,  # 'gpu id'
    'random_seed': 1,  # 'random seed number'

    'eval_fix': False,

    # Training hyperparameters
    'train_update_iter': 1000,  # 'train update iter'
    'lr': 2e-05,  # 'the base learning rate of the generator'
    'momentum': 0.9,  # 'momentum'
    'weight_decay': 0.01,  # 'weight decay'
    'mini_batch_size': 60,  # 'mini batch size for SGD'
    'reduced_train_prop': None,  # 'proportion of samples allocated to train at each time step'
    'reduction': 'mean',
    'eval_freq': 50,
    'patience': 3,

    # Evaluation
    'offline': False,  # help='evaluate offline at a single time step split'
    'difficulty': False,  # 'task difficulty'
    # todo: set value of split_time
    'split_time': 2008,  # 'timestep to split ID vs OOD'
    'test_time': 2008,
    'eval_next_timestamps': 1,  # 'number of future timesteps to evaluate on'
    'eval_worst_time': False,  # 'evaluate worst timestep accuracy'
    'load_model': False,  # 'load trained model for evaluation only'
    'eval_metric': 'acc',  # choices=['acc', 'f1', 'rmse']
    'eval_all_timestamps': False,  # 'evaluate at ID and OOD time steps'

    # ER
    'replay_freq': 50,  # 'number of previous timesteps to finetune on'

    # GroupDRO
    'num_groups': 3,  # 'number of windows for Invariant Learning baselines'
    'group_size': 2,  # 'window size for Invariant Learning baselines'
    'non_overlapping': False,  # 'non-overlapping time windows'

    # EWC
    'ewc_lambda': 0.5,  # help='how strong to weigh EWC-loss ("regularisation strength")'
    'gamma': 1.0,  # help='decay-term for old tasks (contribution to quadratic term)'
    'online': True,  # help='"online" (=single quadratic term) or "offline" (=quadratic term per task) EWC'
    'fisher_n': None,  # help='sample size for estimating FI-matrix (if "None", full pass over dataset)'
    'emp_FI': False,  # help='if True, use provided labels to calculate FI ("empirical FI"); else predicted labels'

    # A-GEM
    'buffer_size': 10,  # 'buffer size for A-GEM'

    # CORAL
    'coral_lambda': 0.5,  # 'how strong to weigh CORAL loss'

    # IRM
    'irm_lambda': 1.0,  # 'how strong to weigh IRM penalty loss'
    'irm_penalty_anneal_iters': 0,  # 'number of iterations after which we anneal IRM penalty loss'

    # Logging, saving, and testing options
    'data_dir': './data',  # 'directory for datasets.'
    'log_dir': './checkpoints',  # 'directory for summaries and checkpoints.'
    'results_dir': './results',  # 'directory for summaries and checkpoints.'
    'num_workers': 0  # 'number of workers in data generator'
}
from munch import DefaultMunch
args = DefaultMunch.fromDict(config)

In [18]:
PREPROCESSED_FILE = 'uklex18.pkl'
MAX_TOKEN_LENGTH = 512
RAW_DATA_FILE = 'uk-lex18.jsonl'
ID_HELD_OUT = 0.2
GROUP = 7

In [19]:
raw_data_path = os.path.join(args.data_dir, RAW_DATA_FILE)
if not os.path.isfile(raw_data_path):
    raise ValueError(f'{RAW_DATA_FILE} is not in the data directory {args.data_dir}!')

# Load data frame from json file, group by year
base_df = pd.read_json(raw_data_path, lines=True)
base_df = base_df.sort_values(by=['year'])
base_df

Unnamed: 0,id,year,labels,title,body,data_type
0,UKSI19750515,1975,[SOCIAL SECURITY],The Social Security (Guardian's Allowances) Re...,"Citation, commencement and interpretation\n1 1...",train
1,UKSI19761267,1976,[SOCIAL SECURITY],The Child Benefit and Social Security (Fixing ...,"Citation, commencement and interpretation\n1 1...",train
2,UKSI19760965,1976,[SOCIAL SECURITY],The Child Benefit (General) Regulations 1976,"PART I\nGeneral\nCitation, commencement and in...",train
3,UKSI19790628,1979,[SOCIAL SECURITY],The Social Security (Claims and Payments) Regu...,P art I\nGENERAL\nCitation and commencement\n1...,train
4,UKSI19821163,1982,[TRANSPORTATION],The Motorways Traffic (England and Wales) Regu...,Commencement and citation\n1\nThese Regulation...,train
...,...,...,...,...,...,...
36450,UKSI20180626,2018,[SOCIAL SECURITY],The Scotland Act 1998 (Agency Arrangements) (S...,Citation\n1\nThis Order may be cited as the Sc...,test
36449,UKSI20180431,2018,[HEALTH CARE],The Plymouth Hospitals National Health Service...,"Citation, commencement and interpretation\n1\n...",test
36448,UKSI20180061,2018,[SOCIAL SECURITY],The Social Fund Funeral Expenses Amendment Reg...,Citation and commencement\n1\nThese Regulation...,test
36454,UKSI20180221,2018,[HOUSING],The Licensing of Houses in Multiple Occupation...,Citation and Commencement\n1\n1\nThis Order ma...,test


In [20]:
df_years = base_df.groupby(pd.Grouper(key='data_type'))
all_dfs = [group for _, group in df_years]
# all_years = list(base_df['data_type'].unique())
dfs = [all_dfs[0][:10000], all_dfs[0][10000:], all_dfs[1], all_dfs[2]]
# years = []
# dfs.append(pd.concat(all_dfs[:8]))
# years.append(all_years[7])
# all_dfs = all_dfs[8:]
# all_years = all_years[8:]
# for i in range(math.ceil(len(all_years)/GROUP)):
#     try:
#         dfs.append(pd.concat(all_dfs[GROUP*i:GROUP*i+GROUP]))
#         years.append(all_years[GROUP*i + 1])
#     except:
#         dfs.append(pd.concat(all_dfs[GROUP*i:]))
#         years.append(all_years[-1])

In [3]:
import nltk
import string
nltk.download('stopwords')
from nltk.corpus import stopwords
remove_these = set(stopwords.words('english') + list(string.punctuation) + list(string.digits))
from nltk.tokenize import RegexpTokenizer
nltk.download('punkt')
tokenizer = RegexpTokenizer(r'\w+')

[nltk_data] Downloading package stopwords to /Users/luke/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /Users/luke/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [4]:
import numpy as np
import pandas as pd
import seaborn as sns
from collections import Counter
import collections
import ast

In [None]:
labels_counts = {}
for i in range(4):
    dfs[i]['tokenized'] = dfs[i]['body'].apply(tokenizer.tokenize)
    labels = [label for lbs in dfs[i]['tokenized'] for label in lbs]
    all_labels = [w for w in labels if not w in remove_these]
    labels_count = Counter(all_labels)
    labels_counts[i+1] = labels_count

In [None]:
def jaccard_set(list1, list2):
    """Define Jaccard Similarity function for two sets"""
    intersection = len(list(set(list1).intersection(list2)))
    union = (len(list1) + len(list2)) - intersection
    return float(intersection) / union

mat = []
for k1, p in labels_counts.items():
    sub = []
    for k2, q in labels_counts.items():
        kl_pq = jaccard_set(p.keys(), q.keys())
        sub.append(kl_pq)
    mat.append(sub)
np.round(mat, 4)

In [None]:
allowed = labels_counts[1].keys() & labels_counts[2].keys() & labels_counts[3].keys() & labels_counts[4].keys()

for i in range(4):
    entries_to_remove = labels_counts[i+1].keys() - allowed
    for k in entries_to_remove:
        labels_counts[i+1].pop(k, None)
    od = collections.OrderedDict(sorted(labels_counts[i+1].items()))
    labels_counts[i+1] = list(od.values())

In [None]:
import matplotlib.pyplot as plt

for k, v in labels_counts.items():
    plt.plot(range(0, len(v)), v, '.-', label=k)
    # NOTE: changed `range(1, 4)` to mach actual values count
plt.legend()  # To draw legend
plt.show()

In [None]:
from math import log2
from scipy.special import rel_entr, kl_div
# calculate the jensen-shannon distance metric
from scipy.spatial.distance import jensenshannon
import numpy as np

# calculate the kl divergence
def kl_divergence(p, q):
	return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))

mat = []
for k1, p in labels_counts.items():
    sub = []
    for k2, q in labels_counts.items():
        # print('P =', k1, ', Q =', k2)
        # calculate (P || Q)
        p = np.asarray(p)
        q = np.asarray(q)
        kl_pq = jensenshannon(p, q, base=2)
        sub.append(kl_pq)
        # print('KL(P|Q): %.3f' % kl_pq)
        # # calculate (Q || P)
        # kl_qp = rel_entr(q, p)
        # print('KL(Q || P): %.3f bits' % sum(kl_pq))
    mat.append(sub)
np.round(mat, 4)

# ECTHR

In [24]:
PREPROCESSED_FILE = 'ecthr_a.pkl'
MAX_TOKEN_LENGTH = 128
RAW_DATA_FILE = ['ecthr-train.jsonl', 'ecthr-dev.jsonl', 'ecthr-test.jsonl']
ID_HELD_OUT = 0.2
GROUP = 1

In [26]:
base_dfs = []
for path in RAW_DATA_FILE:
    raw_data_path = os.path.join(args.data_dir, path)
    if not os.path.isfile(raw_data_path):
        raise ValueError(f'{path} is not in the data directory {args.data_dir}!')
    base_dfs.append(pd.read_json(raw_data_path, lines=True))
# Load data frame from json file, group by year
# base_df = pd.concat(base_dfs)
base_dfs[0] = base_dfs[0].sort_values(by=['judgment_date'])

dfs = [base_dfs[0][:4500], base_dfs[0][4500:], base_dfs[1], base_dfs[2]]

allowed = ['10', '11', '13', '14', '2', '3', '5', '6', '7', '8', '9', 'P1-1', 'P1-3', 'P4-2']
# # allowed = ['10', '11', '13', '14', '18', '2', '3', '4', '5', '6', '7', '8', '9', 'P1-1', 'P4-2', 'P7-1', 'P7-4']
all_dfs = []
for base_df in base_dfs:
    for i in range(len(base_df)):
        new_label = []
        for label in base_df.iloc[i, 8]:
            if label in allowed:
                new_label.append(label)
        base_df.iloc[i, 8] = new_label
    all_dfs.append(base_df)
# base_df['year'] = pd.DatetimeIndex(base_df['judgment_date']).year
# df_years = base_df.groupby(pd.Grouper(key='year'))
# all_dfs = [group for _, group in df_years]

In [27]:
labels_counts = {}
for i in range(4):
    dfs[i]['body'] = dfs[i]['facts'].apply(' '.join)
    dfs[i]['tokenized'] = dfs[i]['body'].apply(tokenizer.tokenize)
    labels = [label for lbs in dfs[i]['tokenized'] for label in lbs]
    all_labels = [w for w in labels if not w in remove_these]
    labels_count = Counter(all_labels)
    labels_counts[i+1] = labels_count

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfs[i]['body'] = dfs[i]['facts'].apply(' '.join)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfs[i]['tokenized'] = dfs[i]['body'].apply(tokenizer.tokenize)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfs[i]['body'] = dfs[i]['facts'].apply(' '.join)
A value is trying to be set on a copy of a s

In [28]:
def jaccard_set(list1, list2):
    """Define Jaccard Similarity function for two sets"""
    intersection = len(list(set(list1).intersection(list2)))
    union = (len(list1) + len(list2)) - intersection
    return float(intersection) / union

mat = []
for k1, p in labels_counts.items():
    sub = []
    for k2, q in labels_counts.items():
        kl_pq = jaccard_set(p.keys(), q.keys())
        sub.append(kl_pq)
    mat.append(sub)
np.round(mat, 4)

array([[1.    , 0.3341, 0.2899, 0.291 ],
       [0.3341, 1.    , 0.2872, 0.2858],
       [0.2899, 0.2872, 1.    , 0.3686],
       [0.291 , 0.2858, 0.3686, 1.    ]])

In [30]:
allowed = labels_counts[1].keys() & labels_counts[2].keys() & labels_counts[3].keys() & labels_counts[4].keys()

for i in range(4):
    entries_to_remove = labels_counts[i+1].keys() - allowed
    for k in entries_to_remove:
        labels_counts[i+1].pop(k, None)
    od = collections.OrderedDict(sorted(labels_counts[i+1].items()))
    labels_counts[i+1] = list(od.values())

In [31]:
from math import log2
from scipy.special import rel_entr, kl_div
# calculate the jensen-shannon distance metric
from scipy.spatial.distance import jensenshannon
import numpy as np

# calculate the kl divergence
def kl_divergence(p, q):
	return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))

mat = []
for k1, p in labels_counts.items():
    sub = []
    for k2, q in labels_counts.items():
        # print('P =', k1, ', Q =', k2)
        # calculate (P || Q)
        p = np.asarray(p)
        q = np.asarray(q)
        kl_pq = jensenshannon(p, q, base=2)
        sub.append(kl_pq)
        # print('KL(P|Q): %.3f' % kl_pq)
        # # calculate (Q || P)
        # kl_qp = rel_entr(q, p)
        # print('KL(Q || P): %.3f bits' % sum(kl_pq))
    mat.append(sub)
np.round(mat, 4)

array([[0.    , 0.1729, 0.2312, 0.2359],
       [0.1729, 0.    , 0.1405, 0.1551],
       [0.2312, 0.1405, 0.    , 0.1458],
       [0.2359, 0.1551, 0.1458, 0.    ]])

In [None]:
import matplotlib.pyplot as plt

for k, v in labels_counts.items():
    plt.plot(range(0, len(v)), v, '.-', label=k)
    # NOTE: changed `range(1, 4)` to mach actual values count
plt.legend()  # To draw legend
plt.show()

# EURLEX

In [5]:
from datasets import load_dataset
dataset = load_dataset('multi_eurlex', language='en', label_level='level_1')

Found cached dataset multi_eurlex (/Users/luke/.cache/huggingface/datasets/multi_eurlex/default-label_level=level_1,language=en/1.0.0/5a12a7463045d4dcb12896b478c09b5a8a131a02b7e7bce059ba7ececc6584ee)


  0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
import re
dfs = []

splits = ['train', 'validation', 'test']
for x in splits:
    df = pd.DataFrame(columns=['celex_id', 'text', 'labels', 'year'])
    for i in dataset[x]:
        valid_months = ["January", "February", "March", "April", "May", "June", "July", "August", "September",
                        "October", "November", "December"]
        pattern = r'(\d{1,2})\s*(' + '|'.join(valid_months) + ')\s*(\d{4})'
        # matches = list(re.finditer(pattern, i['title'], re.IGNORECASE))

        # if len(matches) > 0:
        #     year = matches[0].group(3)
        # else:
        #
        matches = list(re.finditer(pattern, i['text'], re.IGNORECASE))
        if len(matches) > 0:
            year = matches[0].group(3)
        # else:
            # print(i['celex_id'], i['text'][:100], '\n')
        elif i['celex_id'] == '31988R0091':
            year = 1988
        elif i['celex_id'] in ['31987D0594', '31987D0593']:
            year = 1987

        df.loc[len(df)] = list(i.values()) + [int(year)]
    df = df.sort_values(by=['year'])
    if x == 'train':
        dfs += [df[:27500], df[27500:]]
    else:
        dfs.append(df)

In [25]:
# dfs = [pd.DataFrame(dataset['train'])[:27500], pd.DataFrame(dataset['train'])[27500:], pd.DataFrame(dataset['validation']), pd.DataFrame(dataset['test'])]

labels_counts = {}
for i in range(4):
    dfs[i]['tokenized'] = dfs[i]['text'].apply(tokenizer.tokenize)
    labels = [label for lbs in dfs[i]['tokenized'] for label in lbs]
    all_labels = [w for w in labels if not w in remove_these]
    labels_count = Counter(all_labels)
    labels_counts[i+1] = labels_count

In [26]:
def jaccard_set(list1, list2):
    """Define Jaccard Similarity function for two sets"""
    intersection = len(list(set(list1).intersection(list2)))
    union = (len(list1) + len(list2)) - intersection
    return float(intersection) / union

mat = []
for k1, p in labels_counts.items():
    sub = []
    for k2, q in labels_counts.items():
        kl_pq = jaccard_set(p.keys(), q.keys())
        sub.append(kl_pq)
    mat.append(sub)
np.round(mat, 4)

array([[1.    , 0.3134, 0.2277, 0.2374],
       [0.3134, 1.    , 0.2809, 0.2859],
       [0.2277, 0.2809, 1.    , 0.4333],
       [0.2374, 0.2859, 0.4333, 1.    ]])

In [27]:
allowed = labels_counts[1].keys() & labels_counts[2].keys() & labels_counts[3].keys() & labels_counts[4].keys()

for i in range(4):
    entries_to_remove = labels_counts[i+1].keys() - allowed
    for k in entries_to_remove:
        labels_counts[i+1].pop(k, None)
    od = collections.OrderedDict(sorted(labels_counts[i+1].items()))
    labels_counts[i+1] = list(od.values())

In [28]:
from math import log2
from scipy.special import rel_entr, kl_div
# calculate the jensen-shannon distance metric
from scipy.spatial.distance import jensenshannon
import numpy as np


# calculate the kl divergence
def kl_divergence(p, q):
    return sum(p[i] * log2(p[i] / q[i]) for i in range(len(p)))


mat = []
for k1, p in labels_counts.items():
    sub = []
    for k2, q in labels_counts.items():
        # print('P =', k1, ', Q =', k2)
        # calculate (P || Q)
        p = np.asarray(p)
        q = np.asarray(q)
        kl_pq = jensenshannon(p, q, base=2)
        sub.append(kl_pq)
        # print('KL(P|Q): %.3f' % kl_pq)
        # # calculate (Q || P)
        # kl_qp = rel_entr(q, p)
        # print('KL(Q || P): %.3f bits' % sum(kl_pq))
    mat.append(sub)
np.round(mat, 4)

array([[0.    , 0.2391, 0.3321, 0.3592],
       [0.2391, 0.    , 0.2167, 0.2648],
       [0.3321, 0.2167, 0.    , 0.1805],
       [0.3592, 0.2648, 0.1805, 0.    ]])

In [32]:
for i in range(4):
    print('\n'.join(np.round(mat, 4)[:, i].astype(str)), '\n')

0.0
0.1729
0.2312
0.2359 

0.1729
0.0
0.1405
0.1551 

0.2312
0.1405
0.0
0.1458 

0.2359
0.1551
0.1458
0.0 



In [24]:
dfs[2]

Unnamed: 0,celex_id,text,labels,year
3406,32010D0491,COUNCIL DECISION\nof 27 July 2009\non the sign...,"[4, 11, 2, 5, 3, 15]",2009
2499,32010R0476,COMMISSION REGULATION (EU) No 476/2010\nof 31 ...,"[3, 17, 15]",2010
1845,32010R1183,COMMISSION REGULATION (EU) No 1183/2010\nof 14...,"[2, 17, 6]",2010
1844,32010R1194,COMMISSION REGULATION (EU) No 1194/2010\nof 14...,"[11, 8, 18, 6]",2010
1843,32010D0806,DECISION OF THE EUROPEAN PARLIAMENT AND OF THE...,"[4, 19, 9, 18, 15]",2010
...,...,...,...,...
2613,32012R0122,COMMISSION IMPLEMENTING REGULATION (EU) No 122...,"[0, 3, 17, 6]",2012
4173,32012D0483,COMMISSION DECISION\nof 20 August 2012\nsettin...,"[3, 6]",2012
0,32012R0782,COMMISSION IMPLEMENTING REGULATION (EU) No 782...,"[3, 2, 17, 6]",2012
3136,32013L0060,COMMISSION DIRECTIVE 2013/60/EU\nof 27 Novembe...,"[7, 3, 8]",2013


In [None]:
dfs = [pd.DataFrame(dataset['train'])[:27500], pd.DataFrame(dataset['train'])[27500:], pd.DataFrame(dataset['validation']), pd.DataFrame(dataset['test'])]

labels_counts = {}
for i in range(4):
    all_labels = [label for lbs in dfs[i]['labels'] for label in lbs]
    labels_count = Counter(all_labels)
    for k in labels_count.keys():
        labels_count[k] = labels_count[k]/sum(labels_count.values())
    if i == 2:
        labels_count[66] = 0
    od = collections.OrderedDict(sorted(labels_count.items()))
    labels_counts[i+1] = list(od.values())

from math import log2
from scipy.special import rel_entr, kl_div
# calculate the jensen-shannon distance metric
from scipy.spatial.distance import jensenshannon
import numpy as np

# calculate the kl divergence
def kl_divergence(p, q):
	return sum(p[i] * log2(p[i]/q[i]) for i in range(len(p)))

mat = []
for k1, p in labels_counts.items():
    sub = []
    for k2, q in labels_counts.items():
        # print('P =', k1, ', Q =', k2)
        # calculate (P || Q)
        p = np.asarray(p)
        q = np.asarray(q)
        kl_pq = jensenshannon(p, q, base=2)
        sub.append(kl_pq)
        # print('KL(P|Q): %.3f' % kl_pq)
        # # calculate (Q || P)
        # kl_qp = rel_entr(q, p)
        # print('KL(Q || P): %.3f bits' % sum(kl_pq))
    mat.append(sub)
np.round(mat, 4)

In [None]:
import matplotlib.pyplot as plt

for k, v in labels_counts.items():
    plt.plot(range(0, len(v)), v, '.-', label=k)
    # NOTE: changed `range(1, 4)` to mach actual values count
plt.legend()  # To draw legend
plt.show()