# Few shot learning

In this notebook we look at few-shot / zero-shot learning as a way to fit a model when little labelled data is available.

https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8693837

# Deep learning

Deep learning methods have achieved great success across several domains and tasks in the past few years. However, these supervised learning methods demand large amounts of labelled data.


If we don't have enough labelled data, we can address these problems with:

* Zero shot learning

* One shot learning

* Few shot learning

# Implementation

In [None]:
import pandas as pd
import numpy as np
from random import seed
from random import sample

seed(42)
np.random.seed(42)

from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

import gensim.downloader as api
from gensim.models.keyedvectors import Word2VecKeyedVectors

from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score
from scipy import spatial

from nltk.corpus import stopwords

In [None]:
# download the NLTK resources used for tokenization, stop-word removal and lemmatization
import nltk
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')

#linear algebra,data preprocessing,Csv files
import pandas as pd
import nltk
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


#for data cleaning
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import re
import string


from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix

#for feature selection
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfTransformer

#for classification
from sklearn.naive_bayes import MultinomialNB
from sklearn.naive_bayes import MultinomialNB
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline

#model selection
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.compose import make_column_transformer
from sklearn.preprocessing import FunctionTransformer

[nltk_data] Downloading package punkt to /Users/puneet/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/puneet/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /Users/puneet/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
# Read the train and test splits from the shared data directory.
df_train, df_test = pd.read_csv("../data/train.csv"), pd.read_csv("../data/test.csv")

In [None]:
df_train.shape, df_test.shape

((120000, 3), (7600, 3))

In [None]:
df_train.head()

Unnamed: 0,Class Index,Title,Description
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."


In [None]:
# Merge headline and body into a single 'text' column.
# One shared separator keeps the train and test text built identically.
TITLE_BODY_SEP = " "
df_train['text'] = df_train['Title'] + TITLE_BODY_SEP + df_train['Description']
df_test['text'] = df_test['Title'] + TITLE_BODY_SEP + df_test['Description']

In [None]:
# Title and Description are already folded into 'text', so drop them.
# `columns=` is used instead of a positional axis (`drop(labels, 1)`), which
# pandas deprecated and no longer accepts in 2.x.
df_train = df_train.drop(columns=['Title', 'Description'])
df_test = df_test.drop(columns=['Title', 'Description'])

In [None]:
# Human-readable names for the numeric labels stored in 'Class Index'.
categories = {
    1: 'World News',
    2: 'Sports News',
    3: 'Business News',
    4: 'Science-Technology News',
}

# Attach the readable label as 'category', then discard the numeric column.
df_train = (
    df_train
    .assign(category=df_train['Class Index'].map(categories))
    .drop(columns=['Class Index'])
)
df_test = (
    df_test
    .assign(category=df_test['Class Index'].map(categories))
    .drop(columns=['Class Index'])
)

In [None]:
list(categories.values())

['World News', 'Sports News', 'Business News', 'Science-Technology News']

In [None]:
df_train.head()

Unnamed: 0,text,category
0,Wall St. Bears Claw Back Into the Black (Reute...,Business News
1,Carlyle Looks Toward Commercial Aerospace (Reu...,Business News
2,Oil and Economy Cloud Stocks' Outlook (Reuters...,Business News
3,Iraq Halts Oil Exports from Main Southern Pipe...,Business News
4,"Oil prices soar to all-time record, posing new...",Business News


# Data Cleaning

In [None]:

def remove_punc(text):
    """Strip markup, links, punctuation and digit-bearing words from `text`.

    The substitutions run in this order:
      1. square-bracketed spans such as ``[1]``
      2. URLs starting with ``http://``, ``https://`` or ``www.``
      3. angle-bracketed tags such as ``<b>``
      4. every character in ``string.punctuation``
      5. newline characters
      6. any word that contains at least one digit

    Args:
        text: the string to clean.

    Returns:
        The cleaned string. Whitespace surrounding the removed pieces is kept.
    """
    # Raw strings keep the regex escapes (\[, \S, \w, \d) away from Python's own
    # string-escape handling, which otherwise warns about invalid escape sequences.
    text = re.sub(r'\[.*?\]', '', text)
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'<.*?>+', '', text)
    text = re.sub('[%s]' % re.escape(string.punctuation), '', text)
    text = re.sub(r'\n', '', text)
    text = re.sub(r'\w*\d\w*', '', text)
    return text

#normalizing case

def normalize(text):
    """Lower-case `text`, tokenize it with NLTK and re-join the tokens with single spaces."""
    tokens = word_tokenize(text.lower())
    joined = " ".join(tokens)
    return joined.strip()


nltk_stop_words = nltk.corpus.stopwords.words('english')
# Set copy of the list above: each membership test becomes a hash lookup
# instead of a scan over the whole list for every word.
_nltk_stop_word_set = frozenset(nltk_stop_words)


def remove_stop(text):
    """Drop English stop words from a whitespace-separated string.

    Args:
        text: input string; it is split on whitespace and every word is
            compared as-is (case-sensitively) against the stop-word list.

    Returns:
        The remaining words joined by single spaces.
    """
    kept_words = [word for word in text.split() if word not in _nltk_stop_word_set]
    return " ".join(kept_words)

lemmatizer = WordNetLemmatizer()


def lemma(text):
    """Lemmatize every whitespace-separated word of `text`.

    Args:
        text: input string; it is split on whitespace.

    Returns:
        The lemmatized words joined by single spaces.
    """
    # The leftover `import pdb` debugging import was removed; nothing used it.
    return " ".join(lemmatizer.lemmatize(word) for word in text.split())



In [None]:
# Clean the raw text of both splits, one string at a time.
# `Series.apply` has no axis argument, so no second positional value is passed:
# it would be bound to an unrelated parameter of `apply`.
df_train['text'] = df_train['text'].apply(remove_punc)
df_test['text'] = df_test['text'].apply(remove_punc)

# Feature Extraction

In [None]:
from sklearn.preprocessing import FunctionTransformer

def pipelinize(function, active=True):
    """Wrap an element-wise `function` as a scikit-learn ``FunctionTransformer``.

    Args:
        function: callable applied to each item of the incoming list or Series.
        active: when False, the transformer returns its input untouched.

    Returns:
        A ``FunctionTransformer`` (created with ``validate=False``) that maps
        `function` over its input and returns the results as a list.
    """
    def list_comprehend_a_function(list_or_series, active=True):
        # Inactive step: hand the data straight back.
        if not active:
            return list_or_series
        return [function(item) for item in list_or_series]

    return FunctionTransformer(
        list_comprehend_a_function,
        validate=False,
        kw_args={'active': active},
    )

In [None]:
# Cleaning functions, each wrapped as a transformer, in the order they run.
cleaning_steps = [pipelinize(f) for f in (remove_punc, normalize, remove_stop, lemma)]

# Clean the text, count tokens, then re-weight the counts with TF-IDF.
text_processing = make_pipeline(*cleaning_steps, CountVectorizer(), TfidfTransformer())

# Only the 'text' column is fed to the text pipeline; all other columns are dropped.
pre_processing = make_column_transformer((text_processing, 'text'), remainder='drop')

# Model 

### Naive bayes

In [None]:
# Multinomial Naive Bayes on top of the shared text preprocessing.
model = make_pipeline(pre_processing, MultinomialNB())
model.fit(df_train, df_train['category'])

Pipeline(steps=[('columntransformer',
                 ColumnTransformer(transformers=[('pipeline',
                                                  Pipeline(steps=[('functiontransformer-1',
                                                                   FunctionTransformer(func=<function pipelinize.<locals>.list_comprehend_a_function at 0x1e1500a60>,
                                                                                       kw_args={'active': True})),
                                                                  ('functiontransformer-2',
                                                                   FunctionTransformer(func=<function pipelinize.<locals>.list_comprehend_a_function at...
                                                                   FunctionTransformer(func=<function pipelinize.<locals>.list_comprehend_a_function at 0x1e26dd9d0>,
                                                                                       kw_args={'active': True})),

In [None]:
predicted_category = model.predict(df_test)

In [None]:
print(classification_report(df_test.category,predicted_category))

                         precision    recall  f1-score   support

          Business News       0.86      0.86      0.86      1900
Science-Technology News       0.88      0.87      0.88      1900
            Sports News       0.95      0.98      0.96      1900
             World News       0.91      0.89      0.90      1900

               accuracy                           0.90      7600
              macro avg       0.90      0.90      0.90      7600
           weighted avg       0.90      0.90      0.90      7600



### SVM classifier

In [None]:
# Linear classifier trained with SGD: hinge loss with an L2 penalty.
svm_classifier = SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3, random_state=42)
model = make_pipeline(pre_processing, svm_classifier)
model.fit(df_train, df_train['category'])

Pipeline(steps=[('columntransformer',
                 ColumnTransformer(transformers=[('pipeline',
                                                  Pipeline(steps=[('functiontransformer-1',
                                                                   FunctionTransformer(func=<function pipelinize.<locals>.list_comprehend_a_function at 0x1e1500a60>,
                                                                                       kw_args={'active': True})),
                                                                  ('functiontransformer-2',
                                                                   FunctionTransformer(func=<function pipelinize.<locals>.list_comprehend_a_function at...
                                                                   FunctionTransformer(func=<function pipelinize.<locals>.list_comprehend_a_function at 0x1e26dd9d0>,
                                                                                       kw_args={'active': True})),

In [None]:
predicted_category = model.predict(df_test)

In [None]:
print(classification_report(df_test.category,predicted_category))

                         precision    recall  f1-score   support

          Business News       0.86      0.83      0.85      1900
Science-Technology News       0.88      0.83      0.85      1900
            Sports News       0.88      0.98      0.93      1900
             World News       0.90      0.88      0.89      1900

               accuracy                           0.88      7600
              macro avg       0.88      0.88      0.88      7600
           weighted avg       0.88      0.88      0.88      7600



# Training a neural network

In [None]:
from flair.data import Corpus
from flair.datasets import TREC_6
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentRNNEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.models.text_classification_model import TARSClassifier
from flair.data import Sentence


In [None]:
def create_flair_dataset(row):
    """Build a flair ``Sentence`` from ``row['text']`` labelled with ``row['category']``.

    The label is stored under the label type 'news'. The return value is
    whatever ``Sentence.add_label`` returns.
    """
    sentence = Sentence(row['text'])
    return sentence.add_label('news', row['category'])

# Turn every row of both splits into a labelled flair Sentence.
res_train = df_train.apply(create_flair_dataset, axis=1)
res_test = df_test.apply(create_flair_dataset, axis=1)

# Hand the sentences to flair as plain Python lists.
train_sentences = res_train.values.tolist()
test_sentences = res_test.values.tolist()
corpus = Corpus(train=train_sentences, test=test_sentences)

In [None]:
corpus.train[0]

Sentence: "Carlyle Looks Toward Commercial Aerospace Reuters Reuters Private investment firm Carlyle Groupwhich has a reputation for making welltimed and occasionallycontroversial plays in the defense industry has quietly placedits bets on another part of the market"   [− Tokens: 35  − Sentence-Labels: {'news': [Business News (1.0)]}]

In [None]:
# 2. what tag do we want to predict?
tag_type = 'news'

# 3. make the tag dictionary from the corpus
# 2. create the label dictionary
label_dict = corpus.make_label_dictionary()

# 4. initialize embeddings
embedding_types = [

    WordEmbeddings('glove'),

    # comment in this line to use character embeddings
    # CharacterEmbeddings(),

    # comment in these lines to use flair embeddings
    # FlairEmbeddings('news-forward'),
    # FlairEmbeddings('news-backward'),
]

# 4. initialize document embedding by passing list of word embeddings
# Can choose between many RNN types (GRU by default, to change use rnn_type parameter)
document_embeddings = DocumentRNNEmbeddings(embedding_types, hidden_size=256)

# 5. create the text classifier
classifier = TextClassifier(document_embeddings, label_dictionary=label_dict)


# 6. initialize the text classifier trainer
trainer = ModelTrainer(classifier, corpus)

# 7. start the training
trainer.train('resources/taggers/trec',
              learning_rate=0.1,
              mini_batch_size=32,
              anneal_factor=0.5,
              patience=5,
              max_epochs=150)

In [None]:
# Load the final model checkpoint from the directory the training cell wrote to.
classifier = TextClassifier.load('resources/taggers/trec/final-model.pt')

# create an example sentence to classify
sentence = Sentence('Who built the Eiffel Tower ?')

# run the classifier on the sentence, then print the sentence's labels
classifier.predict(sentence)

print(sentence.labels)

# Zero shot learning

In [None]:
from flair.trainers import ModelTrainer

# 1. load base TARS
tars = TARSClassifier.load('tars-base')


2021-03-13 20:31:27,381 loading file /Users/puneet/.flair/models/tars-base-v8.pt
init TARS


In [None]:
examples_to_predict = 100

sample_df_test = df_test.sample(n=examples_to_predict)
#sample_df_test = df_test

def predict(row):
    """Zero-shot classify one row's text with the global `tars` model.

    Args:
        row: a DataFrame row (pandas Series) with a 'text' entry.

    Returns:
        The same row with a 'predict' entry added: the value of the first
        label attached to the sentence, or "" when no label was attached.
    """
    sentence = Sentence(row['text'])
    tars.predict_zero_shot(sentence, list(categories.values()))
    # Test for "no label" explicitly instead of using a bare `except`, which
    # would also hide unrelated errors (typos, model failures, interrupts).
    labels = sentence.get_labels()
    # NOTE(review): the first label is taken as the prediction; if several
    # labels can be attached, confirm the first one is the best-scoring one.
    row['predict'] = labels[0].value if labels else ""
    return row

sample_df_test_pred = sample_df_test.apply(predict, 1)

In [None]:
sample_df_test_pred

Unnamed: 0,text,category,predict
3518,Taiwans Leader Urges China to Begin Talks AP A...,World News,
4158,Colgate Profit Falls on Higher Costs NEW YORK...,Business News,Business News
895,UPDATE Sons Of Gwalia In Administration On Hed...,Business News,Business News
2309,Latham stands by Bali claims candidate Federal...,World News,World News
3335,Analysts See PostStern Ripple Effect Howard St...,Business News,Business News
...,...,...,...
2250,Pantano replaced by Glock Jordan have terminat...,Sports News,Sports News
3968,US consumers unaware of spyware The findings c...,Business News,
3763,Asian Shares Hit by Metals Tumble Oil Reuters ...,Business News,Business News
6225,Favre Does It Again With the Texans nursing a ...,Sports News,Sports News


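Zero-shot TARS reached 73% accuracy on the 100 sampled test articles (see the report below). SVM had 50% accuracy (note: this figure is not reproduced in this notebook; the SGD-trained SVM fitted on the full training set above scored 88%).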

In [None]:
print(classification_report(sample_df_test_pred.category,sample_df_test_pred.predict))

                         precision    recall  f1-score   support

                              0.00      0.00      0.00         0
          Business News       0.89      0.91      0.90        34
Science-Technology News       1.00      0.29      0.44        21
            Sports News       1.00      1.00      1.00        22
             World News       1.00      0.61      0.76        23

               accuracy                           0.73       100
              macro avg       0.78      0.56      0.62       100
           weighted avg       0.96      0.73      0.79       100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# Few shot learning

In [None]:
# N-way , K Samples training

def get_model(classes = list(categories.values()), k=3, base_path='resources/taggers/go_emotions'):
    """Fine-tune a fresh base TARS model on `k` training examples per class.

    Args:
        classes: category names to train on (the N of "N-way"); rows of
            `df_train` whose 'category' is not listed are ignored.
            Defaults to every category.
        k: number of examples sampled per class.
        base_path: directory where the trainer stores the model artifacts.

    Returns:
        The TARS classifier after training on the sampled examples.
    """
    # 1. load base TARS
    tars = TARSClassifier.load('tars-base')

    # 2. draw k examples per requested class, then shuffle them.
    # `groupby(...).sample` keeps the 'category' column on every row; the
    # `groupby(...).apply(lambda x: x.sample(...))` form can lose the grouping
    # column on newer pandas. Requires pandas >= 1.1.
    candidates = df_train[df_train['category'].isin(classes)]
    sample_df = (
        candidates
        .groupby('category')
        .sample(n=k)
        .sample(frac=1)
        .reset_index(drop=True)
    )
    print(sample_df)
    print("size of training data is {}".format(sample_df.shape))

    # NOTE: the test split is built from the same sampled rows as the train
    # split, so any score reported for it during training is not a held-out
    # estimate.
    sample_res_train = sample_df.apply(create_flair_dataset, axis=1)
    sample_res_test = sample_df.apply(create_flair_dataset, axis=1)

    new_corpus = Corpus(train = sample_res_train.values.tolist(), test = sample_res_test.values.tolist())
    print("Corpus ready to load: Train {} , Test: {} ".format(len(new_corpus.train), len(new_corpus.test)))

    # 3. make the model aware of the desired set of labels from the new corpus
    tars.add_and_switch_to_new_task( "NEWS_CLASSIFICATION",
                                     label_dictionary=new_corpus.make_label_dictionary())
    # 4. initialize the text classifier trainer
    trainer = ModelTrainer(tars, new_corpus)

    # 5. start the training
    trainer.train(base_path=base_path, # path to store the model artifacts
              learning_rate=0.02, # use very small learning rate
              mini_batch_size=16,
              mini_batch_chunk_size=4, # optionally set this if transformer is too much for your machine
              max_epochs=10, # terminate after 10 epochs
              )

    return tars

# K = 1

Number of examples per class.

In [None]:
tars = get_model(k=1)

2021-03-13 21:01:19,716 loading file /Users/puneet/.flair/models/tars-base-v8.pt
init TARS
                                                text                 category
0  Russian Ministries Start Agreeing to Kyoto App...               World News
1  Team of Mystery  Many around the NFL say they ...              Sports News
2  Voq smartphone arrives in US  With the economy...  Science-Technology News
3  Bush stands up for strong dollar  President Ge...            Business News
size of training data is (4, 2)
Corpus ready to load: Train 4 , Test: 4 
2021-03-13 21:01:23,462 Computing label dictionary. Progress:


100%|██████████| 8/8 [00:00<00:00, 16735.38it/s]

2021-03-13 21:01:23,466 [b'World News', b'Sports News', b'Science-Technology News', b'Business News']
2021-03-13 21:01:23,468 ----------------------------------------------------------------------------------------------------
2021-03-13 21:01:23,470 Model: "TARSClassifier(
  (document_embeddings): None
  (decoder): None
  (loss_function): None
  (tars_model): TextClassifier(
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
             




2021-03-13 21:01:25,305 epoch 1 - iter 1/1 - loss 0.72547483 - samples/sec: 9.24 - lr: 0.020000
2021-03-13 21:01:25,307 ----------------------------------------------------------------------------------------------------
2021-03-13 21:01:25,308 EPOCH 1 done: loss 0.7255 - lr 0.0200000
2021-03-13 21:01:25,309 BAD EPOCHS (no improvement): 0
saving best model
2021-03-13 21:01:26,089 ----------------------------------------------------------------------------------------------------
2021-03-13 21:01:27,934 epoch 2 - iter 1/1 - loss 0.10859243 - samples/sec: 8.96 - lr: 0.020000
2021-03-13 21:01:27,935 ----------------------------------------------------------------------------------------------------
2021-03-13 21:01:27,936 EPOCH 2 done: loss 0.1086 - lr 0.0200000
2021-03-13 21:01:27,937 BAD EPOCHS (no improvement): 1
2021-03-13 21:01:27,938 ----------------------------------------------------------------------------------------------------
2021-03-13 21:01:30,100 epoch 3 - iter 1/1 - loss 

In [None]:
sample_df_test_pred = sample_df_test.apply(predict, 1)

In [None]:
print(classification_report(sample_df_test_pred.category,sample_df_test_pred.predict))

                         precision    recall  f1-score   support

                              0.00      0.00      0.00         0
          Business News       0.88      0.88      0.88        34
Science-Technology News       0.88      0.33      0.48        21
            Sports News       1.00      0.95      0.98        22
             World News       0.83      0.83      0.83        23

               accuracy                           0.77       100
              macro avg       0.72      0.60      0.63       100
           weighted avg       0.89      0.77      0.81       100



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


# K= 2

Number of examples per class

In [None]:
# Train on 2 examples per class, then score the sampled test rows.
tars = get_model(k=2)
sample_df_test_pred = sample_df_test.apply(predict, axis=1)
print(classification_report(sample_df_test_pred['category'], sample_df_test_pred['predict']))

2021-03-13 21:06:53,375 loading file /Users/puneet/.flair/models/tars-base-v8.pt
init TARS
                                                text                 category
0  Zafi worm proves a holiday pest  The massmaili...  Science-Technology News
1  Former Steelers Maine player Strzelczyk dies a...              Sports News
2  Grower Suggests Opening Your Mind to More Open...            Business News
3   Americans and Briton Are Kidnapped by Rebels ...               World News
4  Mars rovers roll on with new funding  NASA has...  Science-Technology News
5   Former Kmart Execs Charged with Fraud   WASHI...            Business News
6  Palestinian officials rush to bedside of ailin...               World News
7  Lehmann howlers rob Arsenal  Lehmann who was a...              Sports News
size of training data is (8, 2)
Corpus ready to load: Train 7 , Test: 8 
2021-03-13 21:06:57,360 Computing label dictionary. Progress:


100%|██████████| 15/15 [00:00<00:00, 16894.35it/s]

2021-03-13 21:06:57,364 [b'Science-Technology News', b'Sports News', b'Business News', b'World News']
2021-03-13 21:06:57,366 ----------------------------------------------------------------------------------------------------
2021-03-13 21:06:57,369 Model: "TARSClassifier(
  (document_embeddings): None
  (decoder): None
  (loss_function): None
  (tars_model): TextClassifier(
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
             




2021-03-13 21:07:01,372 epoch 1 - iter 1/1 - loss 0.00358509 - samples/sec: 4.09 - lr: 0.020000
2021-03-13 21:07:01,374 ----------------------------------------------------------------------------------------------------
2021-03-13 21:07:01,375 EPOCH 1 done: loss 0.0036 - lr 0.0200000
2021-03-13 21:07:01,533 DEV : loss 0.02817917801439762 - score 1.0
2021-03-13 21:07:01,534 BAD EPOCHS (no improvement): 0
saving best model
2021-03-13 21:07:02,279 ----------------------------------------------------------------------------------------------------
2021-03-13 21:07:07,998 epoch 2 - iter 1/1 - loss 0.00418822 - samples/sec: 2.83 - lr: 0.020000
2021-03-13 21:07:08,000 ----------------------------------------------------------------------------------------------------
2021-03-13 21:07:08,001 EPOCH 2 done: loss 0.0042 - lr 0.0200000
2021-03-13 21:07:08,167 DEV : loss 0.024391956627368927 - score 1.0
2021-03-13 21:07:08,168 BAD EPOCHS (no improvement): 0
saving best model
2021-03-13 21:07:08,92

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Train on 3 examples per class, then score the sampled test rows.
tars = get_model(k=3)
sample_df_test_pred = sample_df_test.apply(predict, axis=1)
print(classification_report(sample_df_test_pred['category'], sample_df_test_pred['predict']))

2021-03-13 21:09:46,329 loading file /Users/puneet/.flair/models/tars-base-v8.pt
init TARS
                                                 text                 category
0   CL Preview  Man UnitedSparta Prague  United ho...              Sports News
1   Beslan children return to school  Schools in t...               World News
2   Amnesty China Arrests Jails Human Rights Defen...               World News
3   Rockies Pitcher Frets Health Not Baseball AP  ...              Sports News
4   Starbucks Profit Climbs Extra Week Helps Reute...            Business News
5   In Spain a missing link  NEW YORK Scientists i...  Science-Technology News
6   Metcalfe Allen back ZigBee startup  Ember a st...  Science-Technology News
7   Powell to say Thursday if Darfur deaths are ge...               World News
8   When dreamers wield hammers  By day Lorenzo Ma...            Business News
9   Electronic Voting Raises New Issues  Touted as...  Science-Technology News
10  Police ID Officer in Red Sox Fan Dea

100%|██████████| 23/23 [00:00<00:00, 29483.19it/s]

2021-03-13 21:09:50,050 [b'Sports News', b'World News', b'Business News', b'Science-Technology News']
2021-03-13 21:09:50,052 ----------------------------------------------------------------------------------------------------
2021-03-13 21:09:50,055 Model: "TARSClassifier(
  (document_embeddings): None
  (decoder): None
  (loss_function): None
  (tars_model): TextClassifier(
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
             




2021-03-13 21:09:55,572 epoch 1 - iter 1/1 - loss 0.39624599 - samples/sec: 2.95 - lr: 0.020000
2021-03-13 21:09:55,574 ----------------------------------------------------------------------------------------------------
2021-03-13 21:09:55,574 EPOCH 1 done: loss 0.3962 - lr 0.0200000
2021-03-13 21:09:55,807 DEV : loss 0.01594771072268486 - score 1.0
2021-03-13 21:09:55,808 BAD EPOCHS (no improvement): 0
saving best model
2021-03-13 21:09:56,583 ----------------------------------------------------------------------------------------------------
2021-03-13 21:10:03,876 epoch 2 - iter 1/1 - loss 0.02269825 - samples/sec: 2.22 - lr: 0.020000
2021-03-13 21:10:03,879 ----------------------------------------------------------------------------------------------------
2021-03-13 21:10:03,880 EPOCH 2 done: loss 0.0227 - lr 0.0200000
2021-03-13 21:10:04,061 DEV : loss 0.012076146900653839 - score 1.0
2021-03-13 21:10:04,062 BAD EPOCHS (no improvement): 0
saving best model
2021-03-13 21:10:04,81

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Train on 10 examples per class, then score the sampled test rows.
tars = get_model(k=10)
sample_df_test_pred = sample_df_test.apply(predict, axis=1)
print(classification_report(sample_df_test_pred['category'], sample_df_test_pred['predict']))

2021-03-13 21:12:03,709 loading file /Users/puneet/.flair/models/tars-base-v8.pt
init TARS
                                                 text                 category
0   Hamas Israel behind Damascus bombing  The Pale...               World News
1   Sprint to Cut  Jobs Reuters  Reuters  Sprint C...  Science-Technology News
2   Former MVP Caminiti Dies of Heart Attack at  R...              Sports News
3   As media darling Conte  pushing it  No one kno...              Sports News
4   Google Investors Await the Dropping of  Millio...  Science-Technology News
5   Euro Disney shareholders back capital increase...            Business News
6   Polls show a tough fight PM  POLLS out today g...               World News
7   Victorious Iraqi forces patrol Samarra  SAMARR...               World News
8   Reborn WorldCom in search of  buyer  MCI the t...            Business News
9   Eye on IT  Sometimes if you make believe somet...  Science-Technology News
10  FDA Encourages Radio Tags on Drug Bo

100%|██████████| 76/76 [00:00<00:00, 44340.95it/s]

2021-03-13 21:12:07,689 [b'World News', b'Sports News', b'Science-Technology News', b'Business News']
2021-03-13 21:12:07,691 ----------------------------------------------------------------------------------------------------
2021-03-13 21:12:07,694 Model: "TARSClassifier(
  (document_embeddings): None
  (decoder): None
  (loss_function): None
  (tars_model): TextClassifier(
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
             




2021-03-13 21:12:17,670 epoch 1 - iter 1/3 - loss 0.69966990 - samples/sec: 1.62 - lr: 0.020000
2021-03-13 21:12:28,389 epoch 1 - iter 2/3 - loss 0.38328363 - samples/sec: 1.49 - lr: 0.020000
2021-03-13 21:12:31,193 epoch 1 - iter 3/3 - loss 0.25881933 - samples/sec: 5.71 - lr: 0.020000
2021-03-13 21:12:31,195 ----------------------------------------------------------------------------------------------------
2021-03-13 21:12:31,196 EPOCH 1 done: loss 0.2588 - lr 0.0200000
2021-03-13 21:12:32,493 DEV : loss 0.19850775599479675 - score 1.0
2021-03-13 21:12:32,495 BAD EPOCHS (no improvement): 0
saving best model
2021-03-13 21:12:33,273 ----------------------------------------------------------------------------------------------------
2021-03-13 21:12:47,065 epoch 2 - iter 1/3 - loss 0.54060596 - samples/sec: 1.17 - lr: 0.020000
2021-03-13 21:12:55,651 epoch 2 - iter 2/3 - loss 0.27841292 - samples/sec: 1.86 - lr: 0.020000
2021-03-13 21:12:57,779 epoch 2 - iter 3/3 - loss 0.19700600 - sa

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
# Train on 30 examples per class, then score the sampled test rows.
tars = get_model(k=30)
sample_df_test_pred = sample_df_test.apply(predict, axis=1)
print(classification_report(sample_df_test_pred['category'], sample_df_test_pred['predict']))

2021-03-13 21:31:35,662 loading file /Users/puneet/.flair/models/tars-base-v8.pt
init TARS
                                                  text  \
0    Bosnian Serb prime minister resigns  The prime...   
1    EU to lift trade sanctions against US amid lin...   
2      critical  flaws fixed in RealPlayer  RealNet...   
3    CHELSEA DUO PAY CREDIT TO PSG  Chelsea striker...   
4    Soyuz spacecraft docks with ISS  MOSCOW Oct  I...   
..                                                 ...   
115  Apple Launches ITunes Music Store in Canada Re...   
116  Titan  Drumroll Please  Imagine an oil drum th...   
117  Sun stands between Liberty and finals  NEW YOR...   
118  Barrera earns majority decision over Morales i...   
119  Bush Kerry Tentatively OK Three Debates  NEW Y...   

                    category  
0                 World News  
1              Business News  
2    Science-Technology News  
3                Sports News  
4    Science-Technology News  
..                       .

100%|██████████| 228/228 [00:00<00:00, 39854.19it/s]

2021-03-13 21:31:40,606 [b'World News', b'Business News', b'Science-Technology News', b'Sports News']
2021-03-13 21:31:40,608 ----------------------------------------------------------------------------------------------------
2021-03-13 21:31:40,611 Model: "TARSClassifier(
  (document_embeddings): None
  (decoder): None
  (loss_function): None
  (tars_model): TextClassifier(
    (document_embeddings): TransformerDocumentEmbeddings(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
             




2021-03-13 21:31:49,224 epoch 1 - iter 1/7 - loss 0.87146622 - samples/sec: 1.88 - lr: 0.020000
2021-03-13 21:31:58,037 epoch 1 - iter 2/7 - loss 0.77223769 - samples/sec: 1.82 - lr: 0.020000
2021-03-13 21:32:10,443 epoch 1 - iter 3/7 - loss 0.56228625 - samples/sec: 1.29 - lr: 0.020000
2021-03-13 21:32:20,458 epoch 1 - iter 4/7 - loss 0.57051513 - samples/sec: 1.60 - lr: 0.020000
2021-03-13 21:32:30,134 epoch 1 - iter 5/7 - loss 0.55605755 - samples/sec: 1.65 - lr: 0.020000
2021-03-13 21:32:38,577 epoch 1 - iter 6/7 - loss 0.46554534 - samples/sec: 1.90 - lr: 0.020000
2021-03-13 21:32:45,552 epoch 1 - iter 7/7 - loss 0.42598019 - samples/sec: 2.29 - lr: 0.020000
2021-03-13 21:32:45,553 ----------------------------------------------------------------------------------------------------
2021-03-13 21:32:45,554 EPOCH 1 done: loss 0.4260 - lr 0.0200000
2021-03-13 21:32:48,123 DEV : loss 0.013242833316326141 - score 1.0
2021-03-13 21:32:48,125 BAD EPOCHS (no improvement): 0
saving best mod

# Fast ai reference

https://github.com/Daammon/AG-News-Classifier/blob/master/Ag_news_Classif.ipynb

In [None]:
# Accuracy (%) for each number of training examples per class (k).
# The two keys are separated by a comma, and both lists have one entry per k.
# NOTE(review): only four accuracies are recorded for the five k values; the
# k=30 slot is left as None. TODO: confirm which k is missing its score and
# fill it in.
data = {'k': [1, 2, 3, 10, 30], 'accuracy': [77, 79, 83, 87, None]}