<a href="https://colab.research.google.com/github/spencervagg99/ArticleSummarizer/blob/master/t5_medical_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# T5 Medical Search

This notebook walks through the setup and training of a T5 model that identifies the labels for medical gloves and, we hope, for future categories we map. It starts with data preprocessing and setup, then moves on to model creation, fine-tuning, and evaluation.

## Setup
This section connects the notebook to a GPU and to our Google Cloud Storage bucket.

### Checking the location of the Colab notebook
We can guess the Google Cloud region in which the Colab notebook is running with a tool such as gcping; the cell below instead looks up the location of the runtime's public IP address with ipinfo.io.

To avoid GCS network egress charges, it may be a good idea to “Factory reset” the Colab runtime until we land in a zone that is on the same continent as our GCS bucket.

NOTE: Bucket crisis-nlp-needs is in us-central1 (i.e. the region should be Iowa)

In [1]:
# Check which region the Colab runtime is running in
!curl ipinfo.io

{
  "ip": "35.227.174.152",
  "hostname": "152.174.227.35.bc.googleusercontent.com",
  "city": "The Dalles",
  "region": "Oregon",
  "country": "US",
  "loc": "45.5946,-121.1787",
  "org": "AS15169 Google LLC",
  "postal": "97058",
  "timezone": "America/Los_Angeles",
  "readme": "https://ipinfo.io/missingauth"
}
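The response above reports Oregon, while the note says the bucket is in Iowa. The comparison can be automated; the following is a minimal sketch that assumes the ipinfo.io response has already been captured as a string. The helper name `region_matches` and the trimmed sample are illustrative and not part of the original notebook.

```python
import json

def region_matches(ipinfo_json: str, expected_region: str) -> bool:
    """Return True if the 'region' field of an ipinfo.io response equals expected_region."""
    info = json.loads(ipinfo_json)
    return info.get("region", "").lower() == expected_region.lower()

# Sample trimmed from the cell output above (hypothetical capture of the response).
sample = '{"city": "The Dalles", "region": "Oregon", "country": "US"}'
print(region_matches(sample, "Iowa"))  # False: the runtime is not in the bucket's region
```

A `False` result here would be the cue to reset the runtime and check again.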

### Mounting To Drive
This section connects to your Google Drive in case you need to access files there later on.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import sys

# Fill in the Google Drive path where you uploaded the assignment
FOLDER_PATH = 'Colab Notebooks/Initium/T5 Medical Search/' #@param { type: "string"}
GOOGLE_DRIVE_PATH = os.path.join('drive', 'My Drive', FOLDER_PATH)
sys.path.append(GOOGLE_DRIVE_PATH)
print(os.listdir(GOOGLE_DRIVE_PATH))

### Connecting To Storage Bucket
We connect to the Google Cloud Storage bucket we created by mounting it with gcsfuse (Cloud Storage FUSE).

In [1]:
# Connecting to Google Cloud
import os

BASE_DIR = "medical-search" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True

if ON_CLOUD:
  print("Setting up GCS access...")
  from google.colab import auth
  auth.authenticate_user()
  print('Connected!')

Setting up GCS access...
Connected!


In [2]:
# Install gcsfuse (Cloud Storage FUSE)
!echo "deb http://packages.cloud.google.com/apt gcsfuse-bionic main" > /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add -
!apt -qq update
!apt -qq install gcsfuse

OK
48 packages can be upgraded. Run 'apt list --upgradable' to see them.
The following NEW packages will be installed:
  gcsfuse
0 upgraded, 1 newly installed, 0 to remove and 48 not upgraded.
Need to get 10.8 MB of archives.
After this operation, 23.1 MB of additional disk space will be used.
Selecting previously unselected package gcsfuse.
(Reading database ... 160772 files and directories currently installed.)
Preparing to unpack .../gcsfuse_0.35.1_amd64.deb ...
Unpacking gcsfuse (0.35.1) ...
Setting up gcsfuse (0.35.1) ...


In [3]:
# Mount a local drive that is connected to our storage bucket
# storage folder will be /folderOnColab/...
!mkdir folderOnColab
!gcsfuse --implicit-dirs --limit-bytes-per-sec -1 --limit-ops-per-sec -1 $BASE_DIR folderOnColab

2021/06/09 19:46:08.201884 Using mount point: /content/folderOnColab
2021/06/09 19:46:08.208978 Opening GCS connection...
2021/06/09 19:46:08.465566 Mounting file system "medical-search"...
2021/06/09 19:46:08.502110 File system has been successfully mounted.
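
Before reading data through the mount, it can help to confirm that the directory really is a mount point and that the bucket contents are visible. This is a minimal sketch using only the standard library; `mount_summary` is a hypothetical helper, and the path `folderOnColab` is the one created in the cell above.

```python
import os

def mount_summary(mount_point: str, max_entries: int = 5) -> dict:
    """Report whether mount_point exists, is a mount, and list a few top-level entries."""
    exists = os.path.isdir(mount_point)
    return {
        "exists": exists,
        "is_mount": os.path.ismount(mount_point) if exists else False,
        "entries": sorted(os.listdir(mount_point))[:max_entries] if exists else [],
    }

# On a fresh runtime without the mount this reports exists=False.
print(mount_summary("folderOnColab"))
```

If `is_mount` is false or `entries` is empty, the later `pd.read_csv` and `pd.read_pickle` calls on `folderOnColab/Data/...` would be expected to fail.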


### Packages

In [4]:
# Only if running on Colab or for the first time
!pip install transformers
!pip install sentencepiece
!pip install wandb

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/d5/43/cfe4ee779bbd6a678ac6a97c5a5cdeb03c35f9eaebbb9720b036680f9a2d/transformers-4.6.1-py3-none-any.whl (2.2MB)
Collecting sacremoses
  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
Collecting tokenizers<0.11,>=0.10.1
  Downloading https://files.pythonhosted.org/packages/d4/e2/df3543e8ffdab68f5acc73f613de9c2b155ac47f162e725dcac87c521c11/tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3MB)
Collecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72

In [5]:
# Initial packages/functions needed and set of stopwords
import nltk
import pickle
import re
import sentencepiece
import string
import sys
import torch
import wandb
import pandas as pd
import numpy as np
from nltk.corpus import stopwords
from sklearn.metrics import f1_score, precision_recall_fscore_support, jaccard_score
from torch import nn, optim
from torch.utils.data import Dataset
from transformers import T5ForConditionalGeneration, T5Tokenizer
from tqdm import tqdm
nltk.download('stopwords')
nltk.download('punkt')
stop_words = set(stopwords.words('english'))

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [6]:
# Make sure we have a GPU
print('Checking running on correct device')
if torch.cuda.is_available():
  print('Good to go!')
else:
  print('Please set GPU via Edit -> Notebook Settings.')

Checking running on correct device
Good to go!


# Data Preparation

In this section we want to format our data so that it can easily be used for our T5 setup. We also want to split the data into train/val/test.

## Creating Input/Output

In [None]:
# LOCAL HOST
# Read in dataframes for clean brand, labels and primary/secondary brand
df = pd.read_csv('/home/Resources/datasets/medical-search/Gloves/full_gloves_gudid_with_brand_corrected.tsv', sep='\t')
brand_df = pd.read_pickle('/home/Resources/datasets/medical-search/Gloves/clean_brand_diff_gloves_gudid_V3.pkl').drop_duplicates(subset='_brand_name_diff')
label_df = pd.read_pickle('/home/Resources/datasets/medical-search/Gloves/gudid_gloves_labelled_V4.pkl').reset_index()

In [10]:
# HOSTED RUNTIME USING GOOGLE FUSE
# Read in dataframes for clean brand, labels and primary/secondary brand
df = pd.read_csv('folderOnColab/Data/full_gloves_gudid_with_brand_corrected.tsv', sep='\t')
brand_df = pd.read_pickle('folderOnColab/Data/clean_brand_diff_gloves_gudid_V3.pkl').drop_duplicates(subset='_brand_name_diff')
label_df = pd.read_pickle('folderOnColab/Data/gudid_gloves_labelled_V4.pkl').reset_index()

In [11]:
# Fix dfs so they are joinable and contain only relevant info
label_df = label_df[['primary_di', 'labels', 'brand_labels', 'ProductID']]
df.primary_di = df.primary_di.apply(lambda x: x.zfill(14) if x[0].isdigit() else x)

# Join dataframes
df = df.merge(label_df, on='primary_di', how='left')
df.head()

Unnamed: 0,primary_di,product_code,product_code_name,brand_name,version_model_number,catalog_number,duns_number,company_name,device_count,device_description,device_sterile,device_description_color_raw,_company_name,_brand_name,_brand_name_diff,primary_brand,secondary_brand,primary_di_str,duns_number_str,labels,brand_labels,ProductID
0,787081126866,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105XX,GDNPF105XX,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126866,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",
1,787081126859,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105X,GDNPF105X,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126859,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",
2,787081126842,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105L,GDNPF105L,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126842,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",
3,787081126835,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105M,GDNPF105M,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126835,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",
4,787081126828,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105S,GDNPF105S,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126828,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",


In [12]:
# Fix branding so that it doesn't cause null training inputs
df.secondary_brand = ['None' if pd.isnull(x) else x for x in df.secondary_brand]
df.primary_brand = ['None' if pd.isnull(x) else x for x in df.primary_brand]
df.brand_name = ['' if pd.isnull(x) else x for x in df.brand_name]

In [13]:
def output_label(labels: dict, brand_labels: dict, company: str, brand: tuple,
                 target: str = 'Query', format: str = 'Label') -> str:
    '''
    Input:
        labels: Dictionary with all labels for the item
        brand_labels: Dictionary of brand-level labels, used as a fallback
                      when the matching entry in labels is empty
        company: Name of manufacturer
        brand: Tuple (primary_brand, secondary_brand)
        target: Str that is either "Query", "Description", or "Single Description" for the three different
                training methods we have used
        format: Str that indicates if we want the original T5 training format
                or the Label format, i.e. color: none | size: medium ...
                Options are "Original" and "Label"
    Output:
        Dictionary of labels converted to a lowercased string output that fits the rules
        needed for the chosen target and the format rules for either Original or Label formatting
        Ex: targets: color: xyz | size: none | type1: abc | ... OR
            targets: <extra_id_0> xyz <extra_id_1> abc <extra_id_2> ...
    '''
    # Color
    if labels['color'] == '':
        if brand_labels['color'] == '':
            color = 'None'
        else:
            color = brand_labels['color']
    else:
        color = labels['color']
    
    # Size
    if labels['size'][1] == '' and labels['size'][0] == '':
        size = ''
    elif labels['size'][1] == '':
        size = re.findall(r'[0-9]+\.?[0-9]*', labels['size'][0])[0]
    elif labels['size'][0] == '':
        size = labels['size'][1] 
    else:
        size = labels['size'][1] + ', ' + re.findall(r'[0-9]+\.?[0-9]*', labels['size'][0])[0]
    if size == '':
        if isinstance(brand_labels['size'], tuple):
            if brand_labels['size'][1] != '' and brand_labels['size'][0] != '':
                size = brand_labels['size'][1] + ', ' + re.findall(r'[0-9]+\.?[0-9]*', brand_labels['size'][0])[0]
            elif brand_labels['size'][1] != '':
                size = brand_labels['size'][1]
            elif brand_labels['size'][0] != '':
                size = re.findall(r'[0-9]+\.?[0-9]*', brand_labels['size'][0])[0]
            else:
                size = 'None'
        else:
            size = 'None'

    # Packaging
    # Boxes
    if labels['package'][0] == '':
        if isinstance(brand_labels['package'], tuple) and brand_labels['package'][0] == '':
            box = 'None'
        elif isinstance(brand_labels['package'], tuple):
            box = re.findall(r'[0-9]+\.?[0-9]*', brand_labels['package'][0])[0] + ' pairs'
        else:
            box = 'None'
    else:
        box = re.findall(r'[0-9]+\.?[0-9]*', labels['package'][0])[0] + ' pairs'

    # Cases
    if labels['package'][1] == '':
        if isinstance(brand_labels['package'], tuple) and brand_labels['package'][1] == '':
            case = 'None'
        elif isinstance(brand_labels['package'], tuple):
            case = re.findall(r'[0-9]+\.?[0-9]*', brand_labels['package'][1])[0] + ' boxes'
        else:
            case = 'None'
    else:
        case = re.findall(r'[0-9]+\.?[0-9]*', labels['package'][1])[0] + ' boxes'

    # Types
    if labels['type1'] == '':
        if brand_labels['type1'] == '':
            t1 = 'None'
        else:
            t1 = brand_labels['type1']
    else:
        t1 = labels['type1']
    
    if labels['type2'] == '':
        if brand_labels['type2'] == '':
            t2 = 'None'
        else:
            t2 = brand_labels['type2']
    else:
        t2 = labels['type2']
    
    # Use
    if len(labels['use']) == 0:
        if isinstance(brand_labels['use'], list) and len(brand_labels['use']) != 0:
            use = ', '.join(brand_labels['use'])
        else:
            use = 'None'
    else:
        use = ', '.join(labels['use'])
    
    # Glove Thickness
    if labels['thickness'] == '':
        if brand_labels['thickness'] == '':
            thickness = 'None'
        else:
            thickness = brand_labels['thickness']
    else:
        thickness = labels['thickness']
    
    if labels['length'] == '':
        if brand_labels['length'] == '':
            length = 'None'
        else:
            length = brand_labels['length']
    else:
        length = labels['length']

    # Different items/formats needed for Query vs Description
    if target == 'Query':
        # Material
        mat = labels['material']
        if len(mat) == 0:
            if isinstance(brand_labels['material'], list):
                mat = brand_labels['material']
                if len(mat) == 0:
                    material = 'None'
                else:
                    material = ', '.join(mat)
            else:
                material = 'None'
        else:
            material = ', '.join(mat)

        if labels['type3'] == '':
            if brand_labels.get('type3', '') == '':
                t3 = 'None'
            else:
                t3 = brand_labels['type3']
        else:
            t3 = labels['type3']
        
        c1 = (company != 'None' and not pd.isnull(company) and company != '')
        b0 = (brand[0] != 'None' and not pd.isnull(brand[0]) and brand[0] != '')
        b1 = (brand[1] != 'None' and not pd.isnull(brand[1]) and brand[1] != '')
        company = re.sub(',', '', company)
        brand0 = re.sub(',', '', brand[0])
        brand1 = re.sub(',', '', brand[1])
        if c1 and b0 and b1:
            brand = f'{company}, {brand0}, {brand1}'
        elif c1 and b0:
            brand = f'{company}, {brand0}'
        elif c1 and b1:
            brand = f'{company}, {brand1}'
        elif b0 and b1:
            brand = f'{brand0}, {brand1}'
        elif c1:
            brand = company
        elif b0:
            brand = brand0
        elif b1:
            brand = brand1
        else:
            brand = 'None'
        

    
        if format == 'Label':
            output = f'color: {color} | size: {size} | type1: {t1} | type2: {t2} | type3: {t3} | ' \
                    + f'boxes: {box} | cases: {case} | use: {use} | material: {material} | ' \
                    + f'thickness: {thickness} | length: {length} | brand: {brand}'
        else:
            output = f'<extra_id_0> {color} <extra_id_1> {size} <extra_id_2> {t1} <extra_id_3> {t2}' \
                    + f' <extra_id_4> {t3} <extra_id_5> {box} <extra_id_6> {case} <extra_id_7> {use}' \
                    + f' <extra_id_8> {material} <extra_id_9> {thickness} <extra_id_10> {length}' \
                    + f' <extra_id_11> {brand} <extra_id_12>'

    elif target == 'Description':
        # Material
        mat = [i for i in labels['material'] if i not in {'latex-free', 'synthetic', 'polymer'}]
        if len(mat) == 0:
            if isinstance(brand_labels['material'], list):
                mat = [i for i in brand_labels['material'] if i not in {'latex-free', 'synthetic', 'polymer'}]
                if len(mat) == 0:
                    material = 'None'
                else:
                    material = ', '.join(mat)
            else:
                material = 'None'
        else:
            material = ', '.join(mat)
        
        company = re.sub(',', '', company)
        brand0 = re.sub(',', '', brand[0])
        brand1 = re.sub(',', '', brand[1])
    
        if format == 'Label':
            output = f'color: {color} | size: {size} | type1: {t1} | type2: {t2} | boxes: {box} | ' \
                    + f'cases: {case} | use: {use} | material: {material} | thickness: {thickness} | ' \
                    + f'length: {length} | primary_brand: {brand0} | secondary_brand: {brand1} | company: {company}'
        else:
            output = f'<extra_id_0> {color} <extra_id_1> {size} <extra_id_2> {t1} <extra_id_3> {t2}' \
                    + f' <extra_id_4> {box} <extra_id_5> {case} <extra_id_6> {use}' \
                    + f' <extra_id_7> {material} <extra_id_8> {thickness} <extra_id_9> {length}' \
                    + f' <extra_id_10> {brand0} <extra_id_11> {brand1} <extra_id_12> {company} <extra_id_13>'
    
    elif target == 'Single Description':
        # Material
        mat = labels['material']
        if len(mat) == 0:
            if isinstance(brand_labels['material'], list):
                mat = brand_labels['material']
                if len(mat) == 0:
                    material = 'None'
                else:
                    material = ', '.join(mat)
            else:
                material = 'None'
        else:
            material = ', '.join(mat)

        if labels['type3'] == '':
            if brand_labels.get('type3', '') == '':
                t3 = 'None'
            else:
                t3 = brand_labels['type3']
        else:
            t3 = labels['type3']
        
        b0 = (brand[0] != 'None' and not pd.isnull(brand[0]) and brand[0] != '')
        b1 = (brand[1] != 'None' and not pd.isnull(brand[1]) and brand[1] != '')
        company = re.sub(',', '', company)
        brand0 = re.sub(',', '', brand[0])
        brand1 = re.sub(',', '', brand[1])
        
        if b0 and b1:
            brand = f'{brand0}, {brand1}'
        elif b0:
            brand = brand0
        elif b1:
            brand = brand1
        else:
            brand = 'None'
        

    
        if format == 'Label':
            output = f'color: {color} | size: {size} | type1: {t1} | type2: {t2} | type3: {t3} | ' \
                    + f'boxes: {box} | cases: {case} | use: {use} | material: {material} | ' \
                    + f'thickness: {thickness} | length: {length} | brand: {brand} | company: {company}'
        else:
            output = f'<extra_id_0> {color} <extra_id_1> {size} <extra_id_2> {t1} <extra_id_3> {t2}' \
                    + f' <extra_id_4> {t3} <extra_id_5> {box} <extra_id_6> {case} <extra_id_7> {use}' \
                    + f' <extra_id_8> {material} <extra_id_9> {thickness} <extra_id_10> {length}' \
                    + f' <extra_id_11> {brand} <extra_id_12> {company} <extra_id_13>'

    
    return 'targets: ' + output.lower()
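
Most scalar fields in `output_label` (color, type1, type2, thickness, length) follow the same rule: use the item label, fall back to the brand label, and otherwise emit `'None'`. The sketch below isolates that rule in a hypothetical helper, `first_non_empty`, which is not part of the original notebook. It deliberately ignores the tuple and list fields (size, package, use, material), which need the extra handling shown in the function above.

```python
def first_non_empty(*candidates, default: str = 'None') -> str:
    """Return the first candidate that is a non-empty string, else default."""
    for value in candidates:
        if isinstance(value, str) and value != '':
            return value
    return default

# Made-up label dictionaries, for illustration only.
labels = {'color': '', 'thickness': '6 mil'}
brand_labels = {'color': 'black', 'thickness': ''}

print(first_non_empty(labels['color'], brand_labels['color']))          # black
print(first_non_empty(labels['thickness'], brand_labels['thickness']))  # 6 mil
print(first_non_empty('', ''))                                          # None
```

Passing the item label first and the brand label second reproduces the precedence used throughout the function.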

In [14]:
np.random.seed(601)
def input_text_brand_placement(device_description: str, brand_name: str, company_name: str) -> str:
    '''
    Returns the correct input string for the T5. We want 
    branding to be in front of description 70% of time and at the end the rest
    '''
    if np.random.uniform() >= 0.7:
        return str(device_description) + ' ' + str(brand_name) + ' ' + str(company_name)
    else:
        return str(brand_name) + ' ' + str(company_name) + ' ' + str(device_description)
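
To sanity-check the 70/30 placement rule, the share of brand-first inputs can be estimated by simulation. This is a sketch only: it uses the standard library `random` module as a stand-in for `np.random.uniform`, so the individual draws differ from the notebook's even with the same seed value, and `brand_first` is a hypothetical helper.

```python
import random

def brand_first(rng: random.Random, threshold: float = 0.7) -> bool:
    """Mirror the notebook's rule: a draw below the threshold puts branding first."""
    return rng.random() < threshold

rng = random.Random(601)  # stdlib stand-in for np.random.seed(601)
n = 10_000
share = sum(brand_first(rng) for _ in range(n)) / n
print(f'brand-first share: {share:.3f}')
```

The printed share should land close to 0.7, with the remainder placing the description first.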

### T5 Setup For Description
This section creates the actual input text and target text from the dataframe built above, for training on the Description task alone. If you want to create a mixture model setup, skip to the next subsection.




In [15]:
# Create input of brand name, device description, company name
df['input_text'] = df.apply(lambda x: input_text_brand_placement(x.device_description, x.brand_name, x.company_name), axis=1)
df.input_text = [re.sub(r'(?<![a-zA-Z])-(?=[a-zA-Z])',' - ',' '.join(nltk.word_tokenize(str(x).lower()))) for x in df.input_text] # Preprocessing

df = df[df.input_text != 'nan'].copy().reset_index() # Drop any nan descriptions
df.input_text = 'description: ' + df.input_text


# Target text as specified from output_label function
# Uncomment this line for label type target
#df['target_text'] = df.apply(lambda x: output_label(x.labels,x.brand_labels, x.company_name, (x.primary_brand, x.secondary_brand)), axis = 1)

# Use this section for the Original-format target text
df['target_text'] = df.apply(lambda x: output_label(x.labels,x.brand_labels, 
                                                              x.company_name, (x.primary_brand, x.secondary_brand), 'Single Description', format='Original'), axis = 1)
input_prompts = '''The color of this item is <extra_id_0>. The size is <extra_id_1>. This item's sterile status is <extra_id_2>. ''' \
            + '''The powdered status is <extra_id_3>. Its latex-free status is <extra_id_4>. It comes in <extra_id_5> per box. ''' \
            + '''There are <extra_id_6> per case. It is used for <extra_id_7>. The material of this item is <extra_id_8>. This item ''' \
            + '''is <extra_id_9> thick. It is <extra_id_10> long. The product name is <extra_id_11>. <extra_id_12> manufactures this.'''
df.input_text = df.input_text + '. ' + input_prompts

df.head()

Unnamed: 0,index,primary_di,product_code,product_code_name,brand_name,version_model_number,catalog_number,duns_number,company_name,device_count,device_description,device_sterile,device_description_color_raw,_company_name,_brand_name,_brand_name_diff,primary_brand,secondary_brand,primary_di_str,duns_number_str,labels,brand_labels,ProductID,input_text,target_text
0,0,787081126866,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105XX,GDNPF105XX,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126866,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",,description: hand armor tranzonic acquisition ...,targets: <extra_id_0> black <extra_id_1> extra...
1,1,787081126859,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105X,GDNPF105X,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126859,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",,description: nitrile powder free exam gloves b...,targets: <extra_id_0> black <extra_id_1> extra...
2,2,787081126842,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105L,GDNPF105L,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126842,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",,description: hand armor tranzonic acquisition ...,targets: <extra_id_0> black <extra_id_1> large...
3,3,787081126835,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105M,GDNPF105M,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126835,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",,description: nitrile powder free exam gloves b...,targets: <extra_id_0> black <extra_id_1> mediu...
4,4,787081126828,['LZA'],['Polymer Patient Examination Glove'],Hand Armor,GDNPF105S,GDNPF105S,90125555,TRANZONIC ACQUISITION CORP.,100,NITRILE POWDER FREE EXAM GLOVES BLACK 6 MIL 10...,,[],Tranzonic Acquisition,Hand Armor,Hand Armor,Hand Armor,,787081126828,90125555,"{'color': 'black', 'package': ('100 pairs', '1...","{'color': '', 'package': ('', ''), 'size': (''...",,description: hand armor tranzonic acquisition ...,targets: <extra_id_0> black <extra_id_1> small...


### T5 Setup For Both Query And Description
This section should be run instead of "T5 Setup For Description" if the task is to create a mixture model. It will create an input for both our Query Task and our Description Task.

The Query Task is designed to be used for processing search queries that we may see through a website search bar. The Description Task would be used when helping process item descriptions on the backend and can be less detailed in some aspects because we would have time for backend processing.

For this setup, we also vary where the branding appears (in front of or behind the description), as we believe that helps train a more robust model.



#### Create Label T5 Training Format
The label format was created as a QA setup in which T5 would read the input of description, branding, and company and automatically fill in the labels for us. The target of this format looks as follows:

```
Input:
[BRANDING] [COMPANY] [DEVICE_DESCRIPTION]

Output:
targets: color: xyz | size: abc | type1: efg | ...
```
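
At evaluation time, a target in the Label format has to be turned back into individual fields. The following is a minimal sketch of such a parser; `parse_label_target` is a hypothetical helper, the example string is made up, and the code assumes that no value contains the ` | ` separator.

```python
def parse_label_target(target: str) -> dict:
    """Split 'targets: k1: v1 | k2: v2' into {'k1': 'v1', 'k2': 'v2'}."""
    prefix = 'targets:'
    body = target[len(prefix):] if target.startswith(prefix) else target
    fields = {}
    for part in body.split(' | '):
        # Split on the first colon only, so values may contain colons or commas.
        key, _, value = part.partition(':')
        fields[key.strip()] = value.strip()
    return fields

example = 'targets: color: black | size: large, 8.5 | type1: sterile | brand: none'
print(parse_label_target(example))
```

Because values such as `size` can themselves contain commas, the parser splits on ` | ` rather than on `, `.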

In [None]:
# Create input text, preprocess, and drop any unwanted inputs
df['input_text'] = df.apply(lambda x: input_text_brand_placement(x.device_description, x.brand_name, x.company_name), axis=1)
df.input_text = [re.sub(r'(?<![a-zA-Z])-(?=[a-zA-Z])',' - ',' '.join(nltk.word_tokenize(str(x).lower()))) for x in df.input_text]
df_desc = df[df.input_text != 'nan'].copy().drop_duplicates(subset='input_text').reset_index(drop=True)
df_query = df_desc.copy()

# Description Task
df_desc['type'] = 'Description'
df_desc.input_text = 'description: ' + df_desc.input_text
df_desc['target_text'] = df_desc.apply(lambda x: output_label(x.labels,x.brand_labels, 
                                                              x.company_name, (x.primary_brand, x.secondary_brand), 'Description'), axis = 1)

# Query Task
df_query['type'] = 'Query'
df_query.input_text = 'query: ' + df_query.input_text
df_query['target_text'] = df_query.apply(lambda x: output_label(x.labels,x.brand_labels, x.company_name, (x.primary_brand, x.secondary_brand)), axis = 1)


#### Create Original T5 Training Format
If you want to create the original T5 training format, run this cell instead of the one above. The key difference between the two is how the model is trained. The Original format is closer to how T5 was originally trained: it takes the input text and appends a prompt in which each label corresponds to an id token. The target is then to fill in each of those tokens. An example can be seen below:

```
Input:
[BRANDING] [COMPANY] [DEVICE_DESCRIPTION] The color is <extra_id_0>. The size is <extra_id_1> ...

Output:
targets: <extra_id_0> abc <extra_id_1> xyz <extra_id_2> ...

```
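
A target in the Original format can be read back by splitting on the sentinel tokens. This is a sketch under the assumption that the target is well formed, meaning it starts with a sentinel and ends with a closing one; `parse_sentinel_target` is a hypothetical helper and the example values are made up.

```python
import re

def parse_sentinel_target(target: str) -> list:
    """Return the text between consecutive <extra_id_N> sentinels, in order."""
    body = re.sub(r'^targets:\s*', '', target)
    pieces = re.split(r'<extra_id_\d+>', body)
    # pieces[0] precedes the first sentinel and pieces[-1] follows the last one;
    # both are empty for a well-formed target, so drop them.
    return [p.strip() for p in pieces[1:-1]]

example = 'targets: <extra_id_0> black <extra_id_1> large <extra_id_2> none <extra_id_3>'
print(parse_sentinel_target(example))  # ['black', 'large', 'none']
```

The position of each value in the returned list corresponds to the sentinel index used in the input prompt.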

In [None]:
# Place branding in front or back, preprocess, and delete unnecessary rows
df['input_text'] = df.apply(lambda x: input_text_brand_placement(x.device_description, x.brand_name, x.company_name), axis=1)
df.input_text = [re.sub(r'(?<![a-zA-Z])-(?=[a-zA-Z])',' - ',' '.join(nltk.word_tokenize(str(x).lower()))) for x in df.input_text]
df_desc = df[df.input_text != 'nan'].copy().drop_duplicates(subset='input_text').reset_index(drop=True)
df_query = df_desc.copy()

# Description Task
df_desc['type'] = 'Description'
df_desc.input_text = 'description: ' + df_desc.input_text
df_desc['target_text'] = df_desc.apply(lambda x: output_label(x.labels,x.brand_labels, 
                                                              x.company_name, (x.primary_brand, x.secondary_brand), 'Description', format='Original'), axis = 1)

# Query Task
df_query['type'] = 'Query'
df_query.input_text = 'query: ' + df_query.input_text
df_query['target_text'] = df_query.apply(lambda x: output_label(x.labels,x.brand_labels, x.company_name, (x.primary_brand, x.secondary_brand), format='Original'), axis = 1)


# Add in input prompts
extra_ids_query = 'The color is <extra_id_0> and size is <extra_id_1> and type1 is <extra_id_2> ' \
                  + ' and type2 is <extra_id_3> and type3 is <extra_id_4> and pairs per box is <extra_id_5> and boxes per case is <extra_id_6>' \
                  + ' and use is <extra_id_7> and material is <extra_id_8> and thickness is <extra_id_9> and length is <extra_id_10> and brand is <extra_id_11>'

extra_ids_desc = 'The color is <extra_id_0> and size is <extra_id_1> and type1 is <extra_id_2> ' \
                  + ' and type2 is <extra_id_3> and pairs per box is <extra_id_4> and boxes per case is <extra_id_5>' \
                  + ' and use is <extra_id_6> and material is <extra_id_7> and thickness is <extra_id_8> and length is <extra_id_9>' \
                  + ' and primary brand is <extra_id_10> and secondary brand is <extra_id_11> and company name is <extra_id_12>'

df_query.input_text = df_query.input_text + '. ' + extra_ids_query
df_desc.input_text = df_desc.input_text + '. ' + extra_ids_desc

## Splitting Into Train/Test/Val
We will split the data into roughly
- Train: 75% 
- Val: 13% 
- Test: 12% 
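
These percentages follow from how the split is made in the cells below: rows that have a `ProductID` become the test set, and the remaining rows are divided at random with an 85% chance of landing in train. If about 12% of rows have a `ProductID`, then train is roughly 0.85 × 0.88 ≈ 75% and val roughly 0.15 × 0.88 ≈ 13%. The sketch below reproduces that arithmetic on synthetic rows using only the standard library; the 12% share, the `product_id` key, and `split_rows` are assumptions for illustration.

```python
import random

def split_rows(rows, train_share=0.85, seed=601):
    """Rows with a product_id form the test set; the rest are split train/val at random."""
    rng = random.Random(seed)
    train, val, test = [], [], []
    for row in rows:
        if row['product_id'] is not None:
            test.append(row)
        elif rng.random() < train_share:
            train.append(row)
        else:
            val.append(row)
    return train, val, test

# Synthetic data: 12% of rows carry a product_id (an assumed share, for illustration).
rows = [{'id': i, 'product_id': i if i % 100 < 12 else None} for i in range(10_000)]
train, val, test = split_rows(rows)
print(len(train) / len(rows), len(val) / len(rows), len(test) / len(rows))
```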

### Splitting For Description
This section splits the description task data into the train/val/test sets. Again, if you are creating the mixture model, go to the "Splitting For Both Query And Description" section.

In [16]:
np.random.seed(601)
df = df.drop_duplicates(subset='input_text').reset_index(drop=True)
temp_df = df[pd.isnull(df.ProductID)]
test_df = df[~ pd.isnull(df.ProductID)].reset_index(drop=True).copy()
msk = np.random.rand(len(temp_df)) < 0.85
train_df = temp_df[msk].reset_index(drop=True).copy()
val_df = temp_df[~ msk].reset_index(drop=True).copy()

In [9]:
# Run cell if using the Original format. Need to fix the target label function
train_df.target_text = [re.sub(r'^targets:\s*', '', x) for x in train_df.target_text]
val_df.target_text = [re.sub(r'^targets:\s*', '', x) for x in val_df.target_text]
test_df.target_text = [re.sub(r'^targets:\s*', '', x) for x in test_df.target_text]

In [13]:
# If no sections below are needed we can save
train_df.to_pickle('folderOnColab/Data/training_set_single_task_original_format.pkl')
test_df.to_pickle('folderOnColab/Data/test_set_single_task_original_format.pkl')
val_df.to_pickle('folderOnColab/Data/validation_set_single_task_original_format.pkl')

In [14]:
train_df = pd.read_pickle('folderOnColab/Data/training_set_single_task_original_format.pkl')
test_df = pd.read_pickle('folderOnColab/Data/test_set_single_task_original_format.pkl')
val_df = pd.read_pickle('folderOnColab/Data/validation_set_single_task_original_format.pkl')

#### Adding Regular and Lower Case
If you want to train the model on both regular and lowercase input, keep the df in regular-cased text and run the section below.

In [None]:
#### IF YOU WANT BOTH NORMAL AND LOWER CASE ######
def both_cases_dataset(df: pd.DataFrame) -> pd.DataFrame:
    '''
    Takes in a dataframe with regular casing text and doubles it with
    input_text being both regular case and lowercase
    '''
    temp = df.copy()
    temp.input_text = [str(x).lower() for x in temp.input_text]
    return pd.concat([df, temp], axis=0).reset_index(drop=True)

In [None]:
train_df = both_cases_dataset(train_df)
test_df = both_cases_dataset(test_df)
val_df = both_cases_dataset(val_df)

### Splitting For Both Query And Description
This section splits the data if you are training a mixture model. 

In [None]:
np.random.seed(601)

temp_df_query = df_query[pd.isnull(df_query.ProductID)]
temp_df_desc = df_desc[pd.isnull(df_desc.ProductID)]

test_df_query = df_query[~ pd.isnull(df_query.ProductID)].reset_index(drop=True).copy()
test_df_desc = df_desc[~ pd.isnull(df_desc.ProductID)].reset_index(drop=True).copy()

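# msk_q is reused for the description rows below, which assumes that both
# frames have the same length and row order
assert len(temp_df_query) == len(temp_df_desc)
msk_q = np.random.rand(len(temp_df_query)) < 0.85
#msk_d = np.random.rand(len(temp_df_desc)) < 0.85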

train_df_query = temp_df_query[msk_q].reset_index(drop=True).copy()
train_df_desc = temp_df_desc[msk_q].reset_index(drop=True).copy()

val_df_query = temp_df_query[~ msk_q].reset_index(drop=True).copy()
val_df_desc = temp_df_desc[~ msk_q].reset_index(drop=True).copy()

train_df = pd.concat([train_df_query, train_df_desc], axis=0).reset_index(drop=True)
test_df = pd.concat([test_df_query, test_df_desc], axis=0).reset_index(drop=True)
val_df = pd.concat([val_df_query, val_df_desc], axis=0).reset_index(drop=True)
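
The cell above applies one mask to both task dataframes. A small sketch of what that buys, under the assumption that the query and description frames are row-aligned (row i of each describes the same item): a shared mask keeps both versions of an item in the same split. The `pair_id` column and the toy frames are hypothetical and exist only to make the check visible.

```python
import numpy as np
import pandas as pd

rng = np.random.RandomState(601)
n = 200

# Two row-aligned toy frames: row i in each refers to the same item
df_query = pd.DataFrame({'pair_id': range(n), 'type': 'query'})
df_desc = pd.DataFrame({'pair_id': range(n), 'type': 'description'})

# One mask, reused for both tasks
msk = rng.rand(n) < 0.85
train = pd.concat([df_query[msk], df_desc[msk]], axis=0).reset_index(drop=True)
val = pd.concat([df_query[~msk], df_desc[~msk]], axis=0).reset_index(drop=True)

# Items whose rows ended up on both sides of the split
leaked = set(train.pair_id) & set(val.pair_id)
print(len(train), len(val), len(leaked))
```

With two independent masks, some items would likely have their query in train and their description in validation, which could make validation scores look better than they should.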

In [None]:
train_df.to_pickle('/home/Resources/datasets/medical-search/Gloves/training_set_multiple_tasks_original_format.pkl')
test_df.to_pickle('/home/Resources/datasets/medical-search/Gloves/test_set_multiple_tasks_original_format.pkl')
val_df.to_pickle('/home/Resources/datasets/medical-search/Gloves/validation_set_multiple_tasks_original_format.pkl')

In [7]:
# Hosted Runtime read in datasets for training/eval
train_df = pd.read_pickle('folderOnColab/Data/training_set_single_task_original_format.pkl')
test_df = pd.read_pickle('folderOnColab/Data/test_set_single_task_original_format.pkl')
val_df = pd.read_pickle('folderOnColab/Data/validation_set_single_task_original_format.pkl')

In [None]:
# Local Runtime read in datasets for training/eval
train_df = pd.read_pickle('/home/Resources/datasets/medical-search/Gloves/training_set_multiple_tasks_original_format.pkl')
test_df = pd.read_pickle('/home/Resources/datasets/medical-search/Gloves/test_set_multiple_tasks_original_format.pkl')
val_df = pd.read_pickle('/home/Resources/datasets/medical-search/Gloves/validation_set_multiple_tasks_original_format.pkl')

# T5 Model and Dataset

## Dataset

In [8]:
class T5MedSearchDataset(Dataset):
    '''
    Wrapper that holds the data we need to train our T5 Model
    '''
    def __init__(self, df: pd.DataFrame, device: torch.device, dataset_type: str = 'Train',
                 task: str = 'Single', input_length: int = 256, output_length: int = 128):
        '''
        df: A dataframe with column names input_text and target_text
        device: Device to send tensors to
        dataset_type: Type of dataset out of choices Train, Validation, and Test
        task: String of number of tasks in dataset. Either "Single" or "Multiple".
                If Multiple then dataset returns task type as well
        input_length: Max number of tokens for the padded input
        output_length: Max number of tokens for the padded target
        '''
        assert(dataset_type in ['Train', 'Validation', 'Test'])
        self.tokenizer = T5Tokenizer.from_pretrained('t5-base')
        self.input_text = df.input_text
        self.target_text = df.target_text
        self.df = df
        self.dataset_type = dataset_type
        self.device = device
        self.task = task
        self.input_length = input_length
        self.output_length = output_length
    def get_labels(self) -> pd.Series:
        return self.target_text
    def __len__(self):
        assert(len(self.input_text) == len(self.target_text))
        return len(self.target_text)
    def __getitem__(self, idx) -> dict:
        inputs = self.input_text[idx]
        targets = self.target_text[idx]
        # truncation=True makes the max_length cut-off explicit
        dctInput = self.tokenizer(inputs, max_length=self.input_length, truncation=True,
                                  padding='max_length', return_tensors='pt')
        dctOutput = self.tokenizer(targets, max_length=self.output_length, truncation=True,
                                   padding='max_length', return_tensors='pt')
        if self.task == 'Multiple':
            return_dict = {
                'input_ids': dctInput.input_ids,
                'attention_mask': dctInput.attention_mask,
                'labels': dctOutput.input_ids,
                'targets': targets,
                'type': self.df.at[idx, 'type']
            }
        else:
            return_dict = {
                'input_ids': dctInput.input_ids,
                'attention_mask': dctInput.attention_mask,
                'labels': dctOutput.input_ids,
                'targets': targets
            }
        return return_dict

## Model Training/Eval Functions

### Evaluation Metrics
The V2 eval metrics allow for a better evaluation: each label phrase is compared as a whole, so word order within a phrase matters. They can also evaluate the Original label format with the extra id tokens.

In [5]:
def comparing_labels(dict1: dict, dict2: dict) -> dict:
    '''
    Input:
        dict1: Dictionary containing the true labels
        dict2: Dictionary containing the predicted labels

        **Uses the keys of the target dict only**
    Output:
        Dictionary with the keys of dict1 and values of 1 if the
        labels matched and 0 if they don't
    '''

    dct = dict()
    for k, v in dict1.items():
        # Make sure that all elements for label are the same
        dct[k] = 1. if set(v) == set(dict2.get(k, [])) else 0.
    return dct

def answer_to_dict(ans: str) -> dict:
    '''
    Input:
        ans: String of the T5 output or target label
    Output:
        A dictionary with all of the label classes as keys and the
        respective label as values
    '''
    labels = [i.strip() for i in re.sub(r'^targets:\s', '', ans).split(',') if i != '' and i != ' ']
    dct = dict()
    for i in labels:
        lbl = str(i).split(':')
        try:
            dct[lbl[0]] = lbl[1].split()
        except IndexError:
            # Skip fragments that have no "label: value" pair
            continue
    return dct
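
A short worked example may help show what the first-version metrics accept. The block below is a condensed, self-contained restatement of the two functions above (not a replacement for them), and the target and prediction strings are made up. Because each value is split into words and compared as a set, a prediction with the words in a different order still counts as a match.

```python
import re

def answer_to_dict(ans: str) -> dict:
    # Condensed restatement of the parser above: "label: value, label: value"
    dct = {}
    for part in re.sub(r'^targets:\s', '', ans).split(','):
        pieces = part.strip().split(':')
        if len(pieces) > 1:
            dct[pieces[0]] = pieces[1].split()
    return dct

def comparing_labels(true: dict, pred: dict) -> dict:
    # 1.0 when the word sets match, 0.0 otherwise; only target keys are scored
    return {k: 1. if set(v) == set(pred.get(k, [])) else 0. for k, v in true.items()}

# Made-up strings for illustration
target = 'targets: color: black, size: extra large, material: nitrile'
predicted = 'color: black, size: large extra, material: latex'

scores = comparing_labels(answer_to_dict(target), answer_to_dict(predicted))
print(scores)  # {'color': 1.0, 'size': 1.0, 'material': 0.0}
```

Note that `size` is scored as correct even though the predicted words are reversed, which is the looseness the V2 functions tighten.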

In [13]:
def comparing_labels_V2(dict1: dict, dict2: dict, n_a: bool = False) -> dict:
    '''
    Input:
        dict1: Dictionary containing the true labels
        dict2: Dictionary containing the predicted labels
        n_a: Bool that controls how a label is scored when its target is none.
            If True, a label whose target and prediction are both none is scored
            as np.nan so that it drops out of a nan-aware accuracy

        **Uses the keys of the target dict only**
    Output:
        Dictionary with the keys of dict1 and values of 1 if the
        labels matched and 0 if they don't. If n_a is True then the value will
        be np.nan if BOTH target and prediction are none, 0 if the prediction is wrong,
        and 1 if the predicted label is correct
    '''

    #assert(set(dict1.keys()) == set(dict2.keys()))
    dct = dict()
    for k, v in dict1.items():
        # Make sure that all elements for label are the same
        if n_a:
            if v == ['none']:
                if dict2.get(k, []) == ['none']:
                    dct[k] = np.nan
                else:
                    dct[k] = 0.
            else:
                dct[k] = 1. if set(v) == set(dict2.get(k, [])) else 0.
        else:
            dct[k] = 1. if set(v) == set(dict2.get(k, [])) else 0.
    return dct

def answer_to_dict_v2(ans: str, format: str = 'Label', id_to_label: dict = None, na = True) -> dict:
    '''
    Input:
        ans: String of the T5 output or target label
        format: str of 'Original' or 'Label' that tells the function how the input will look
        id_to_label: If format equals Original then a id_to_label must be provided 
                        to change id number to label name
        na: Boolean. If True then we include the keys with value predicted as none. If False
            then we remove them
    Output:
        A dictionary with all of the label classes as keys and the
        respective label as values
    '''
    if format == 'Label':
        labels = [i.strip() for i in re.sub(r'^targets:\s', '', ans).split('|') if i != '' and i != ' ']
        dct = dict()
        for i in labels:
            lbl = str(i).split(':')
            try:
                dct[lbl[0]] = [j.strip() for j in lbl[1].split(',')]
            except IndexError:
                # Skip fragments that have no "label: value" pair
                continue
        return dct
    elif format == 'Original':
        if id_to_label is None:
            raise ValueError("id_to_label can't be None when format is set to Original")

        labels = [i.strip() for i in re.sub(r'^targets:\s', '', ans).split('<extra_id_') if i != '' and i != ' ']
        dct = dict()
        for i in labels:
            lbl = re.split('>', str(i), maxsplit=1)
            try:
                if not na and lbl[1].strip() == 'none':
                    continue
                dct[id_to_label.get(lbl[0], None)] = [j.strip() for j in lbl[1].split(',')]
            except IndexError:
                # Skip fragments without a closing ">" after the id number
                continue
        return dct
    else:
        raise ValueError('Incorrect value given for the format argument')
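
To illustrate the Original format and the `n_a` option, here is a condensed, self-contained sketch of the parsing and comparison logic above. The names `parse_original` and `compare`, the shortened `ID_TO_LABEL` mapping, and the example strings are all illustrative assumptions; the real functions above remain the reference.

```python
import math
import re

ID_TO_LABEL = {'0': 'color', '1': 'size', '2': 'material'}  # shortened, illustrative mapping

def parse_original(ans: str, id_to_label: dict) -> dict:
    """Condensed restatement of the 'Original' branch of answer_to_dict_v2."""
    dct = {}
    for chunk in re.sub(r'^targets:\s', '', ans).split('<extra_id_'):
        parts = chunk.strip().split('>', 1)
        if len(parts) == 2:
            # Each value is a list of whole, comma-separated phrases
            dct[id_to_label.get(parts[0])] = [p.strip() for p in parts[1].split(',')]
    return dct

def compare(true: dict, pred: dict, n_a: bool = False) -> dict:
    """Condensed restatement of comparing_labels_V2."""
    out = {}
    for k, v in true.items():
        guess = pred.get(k, [])
        if n_a and v == ['none']:
            # Both none -> NaN (excluded from a nan-aware mean); otherwise wrong
            out[k] = float('nan') if guess == ['none'] else 0.
        else:
            out[k] = 1. if set(v) == set(guess) else 0.
    return out

# Made-up strings for illustration
target = 'targets: <extra_id_0> black <extra_id_1> extra-large <extra_id_2> none'
predicted = '<extra_id_0> black <extra_id_1> large <extra_id_2> none'

t = parse_original(target, ID_TO_LABEL)
p = parse_original(predicted, ID_TO_LABEL)
plain = compare(t, p)
with_na = compare(t, p, n_a=True)
print(plain)    # {'color': 1.0, 'size': 0.0, 'material': 1.0}
print(with_na)  # material becomes nan because target and prediction are both none
```

With `n_a=False` a correctly predicted `none` counts as a hit, which can inflate accuracy for labels that are rarely present; `n_a=True` removes those cases from the average instead.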

### Training and Validation Functions

In [10]:
def cleanup_T5_tokenization(lst: list, input: bool = False) -> list:
    '''
    Takes a list of decoded strings and cleans up each one so that it matches
    the target format, i.e. removes padding and unnecessary tokens.
    Set input=True when cleaning decoded inputs so the final sentinel is kept
    '''
    for i,v in enumerate(lst):
        tmp = re.sub(r'^<pad> ', '', v)
        tmp = re.sub(r'<pad>', '', tmp)
        tmp = re.sub(r'</s>', '', tmp)
        tmp = re.sub(r'<', ' <', tmp).strip()
        if not input:
            # Gets rid of last "extra" <extra_id_> token that target requires
            tmp = re.sub(r'<extra_id_[0-9]+>$', '', tmp)
        lst[i] = tmp.strip()
    return lst 
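
As a quick illustration of what the cleanup does, the sketch below applies the same substitutions to a single string. The helper `cleanup_one` is a hypothetical single-string variant of the function above, and the raw string is a made-up example of what a decoded prediction might look like when special tokens are kept.

```python
import re

def cleanup_one(text: str, is_input: bool = False) -> str:
    # Same steps as cleanup_T5_tokenization, applied to one string
    tmp = re.sub(r'^<pad> ', '', text)
    tmp = tmp.replace('<pad>', '').replace('</s>', '')
    # Make sure every sentinel token is preceded by a space
    tmp = re.sub(r'<', ' <', tmp).strip()
    if not is_input:
        # Drop the trailing sentinel that closes the target sequence
        tmp = re.sub(r'<extra_id_[0-9]+>$', '', tmp)
    return tmp.strip()

raw = '<pad> <extra_id_0> black<extra_id_1> large<extra_id_2></s><pad><pad>'
cleaned = cleanup_one(raw)
print(cleaned)  # <extra_id_0> black <extra_id_1> large
```

The cleaned string has the same shape as a target with its final sentinel removed, so both can be passed through the same parser.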

In [11]:
def validate_model(T5: T5ForConditionalGeneration, data_loader: torch.utils.data.DataLoader, 
                   label_format: str, id_to_label: dict, wand: bool = False, verbose: bool = False) -> None:
    '''
    Input:
        T5: T5 model being validated
        data_loader: Pytorch DataLoader
        label_format: Str that tells us if the targets are Original or Label format
        id_to_label: dictionary with id numbers as keys and label as value. Must not be None when
                    label format is set to Original
        wand: Flag for whether we should log with wandb
        verbose: Bool. If true prints Inputs, Predicted, and Targets
    Output:
        No output returned but it will print and send results
        to wandb
    '''
    
    print('Validating Model')
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    val_loss = []
    results = []
    with torch.no_grad():
        for ind, val in tqdm(enumerate(data_loader)):
            input_ids = val["input_ids"].to(device).squeeze(1)
            attention_mask = val["attention_mask"].to(device).squeeze(1)
            labels = val['labels'].to(device).squeeze(1)
            targets = val['targets']

            loss = T5.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
            val_loss.append(loss.item())

            # Keep the batch dimension so batch_decode still works when a batch holds a single row
            generated_ids = T5.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100)

            if label_format == 'Label':
                predicted_span = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                desc = [tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[i], skip_special_tokens=True)) for i in range(input_ids.shape[0])]  

                for i, v in enumerate(predicted_span):
                    if verbose:
                        print(f'Input: {desc[i]}')
                        print(f'Predicted: {v}')
                        print(f'Target: {targets[i]}', '\n')
                    pred_dict = answer_to_dict_v2(v)
                    targ_dict = answer_to_dict_v2(targets[i])
                    results.append(comparing_labels_V2(targ_dict, pred_dict))
            else:
                predicted_span = tokenizer.batch_decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
                predicted_span = cleanup_T5_tokenization(predicted_span, input=False)
                targets = [re.sub(r'<extra_id_[0-9]+>$', '', t).strip() for t in targets]

                desc = [tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[i], skip_special_tokens=False)) for i in range(input_ids.shape[0])]
                desc = cleanup_T5_tokenization(desc, input=True)

                for i, v in enumerate(predicted_span):
                    if verbose:
                        print(f'Input: {desc[i]}')
                        print(f'Predicted: {v}')
                        print(f'Target: {targets[i]}', '\n')
                    pred_dict = answer_to_dict_v2(v, 'Original', id_to_label)
                    targ_dict = answer_to_dict_v2(targets[i], 'Original', id_to_label)
                    results.append(comparing_labels_V2(targ_dict, pred_dict, n_a=False))
    
    
    
    results_df = pd.DataFrame(results)
    print(f'Validation Loss: {np.mean(val_loss)}')
    if wand:
        wandb.log({'Validation Loss': np.mean(val_loss)})
    
    for i in results_df.columns:
        print(f'Accuracy of {i} is: {np.nanmean(results_df[i])}')
        if wand:
            wandb.log({f'{str(i)} accuracy': np.nanmean(results_df[i])})

    return
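
The per-label accuracy at the end of `validate_model` is a nan-aware mean over the list of per-example score dictionaries. The dependency-free sketch below shows the same aggregation on made-up scores; `nan_mean` is a hypothetical stand-in for what `np.nanmean` computes on each column of `results_df`.

```python
import math

# Made-up per-example scores; NaN marks an entry that should not be counted
results = [
    {'color': 1.0, 'size': 0.0, 'length': float('nan')},
    {'color': 1.0, 'size': 1.0, 'length': 1.0},
    {'color': 0.0, 'size': 1.0, 'length': float('nan')},
]

def nan_mean(values):
    # Average only the entries that are not NaN
    kept = [v for v in values if not math.isnan(v)]
    return sum(kept) / len(kept) if kept else float('nan')

accuracy = {label: nan_mean([row[label] for row in results]) for label in results[0]}
for label, acc in accuracy.items():
    print(f'Accuracy of {label} is: {acc}')
```

In the function above, NaN entries can appear when `n_a=True` is used or when a target dictionary is missing a key, since `pd.DataFrame` fills absent keys with NaN.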


In [None]:
####   TEMPORARY CELL USED TO TEST VALIDATION FUNCTION ######
T5 = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
T5.to(device)
val_data = T5MedSearchDataset(val_df, device, 'Validation', input_length=512)
valloader = torch.utils.data.DataLoader(val_data, batch_size=10, shuffle=True)
id_to_label = {'0': 'color', '1': 'size', '2': 'type1', '3': 'type2', '4': 'type3', '5': 'boxes', '6': 'case', 
                     '7': 'use', '8': 'material', '9': 'thickness', '10': 'length', '11': 'brand', '12': 'company'}
validate_model(T5, valloader, 'Original', id_to_label)

In [25]:
def train_T5(train_df: pd.DataFrame, model:T5ForConditionalGeneration, val_df: pd.DataFrame = None,
             epochs: int = 1, batch_size: int = 2, lr: float = 1e-5, log: int = 20,
             wand: bool = False, training_type: bool = False, input_length: int = 256,
             label_format: str = 'Original', id_to_label: dict = None, val_verbose: bool = False) -> T5ForConditionalGeneration:
    '''
    This function aims to train a T5 model on our medical-search labelling task. It relies on wandb

    Input:
        train_df: Pandas df of our training set
        model: T5ForConditionalGeneration model to fine tune
        val_df: Optional df used for validation
        epochs: Number of epochs to train the model
        batch_size: Number of values in each batch
        lr: learning rate
        log: Number of batches to run through before recording key metrics
        wand: Flag for whether we should log with wandb
        training_type: Bool that tells if we are training on multiple tasks or not
        input_length: Int max length of input for model
        label_format: Str that tells us if our labels are Original or Label format
        id_to_label: Dict needed if label_format is set to Original. Keys are the extra_id numbers
                    as strings and values are the corresponding label names
        val_verbose: Bool if true then prints Inputs, Predicted, and Targets
    Output:
        Trained T5 model
    '''
    # Load model and tokenizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print('Loading model and tokenizer')
    T5 = model
    T5.to(device)


    # Create the data loaders for training and validation if needed
    print('Creating datasets')
    if training_type:
        train_data = T5MedSearchDataset(train_df, device, 'Train', task='Multiple', input_length=input_length)
    else:
        train_data = T5MedSearchDataset(train_df, device, 'Train', input_length=input_length)
    trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

    if val_df is not None:
        val_data = T5MedSearchDataset(val_df, device, 'Validation', input_length=input_length)
        valloader = torch.utils.data.DataLoader(val_data, batch_size=10, shuffle=True)
    
    # Set up the model's optimizer
    optimizer = optim.Adam(T5.parameters(), lr=lr)

    # Train the model and get the loss per iteration
    print('Beginning Training')
    for e in range(epochs):
        print(f"Training epoch {e + 1}")
        train_loss = []
        temp_loss = []
        for ind, val in tqdm(enumerate(trainloader)):
            input_ids = val["input_ids"].to(device).squeeze(1)
            attention_mask = val["attention_mask"].to(device).squeeze(1)
            labels = val['labels'].to(device).squeeze(1)
            optimizer.zero_grad()
            loss = T5.forward(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss
            loss.backward()
            optimizer.step()
            train_loss.append(loss.item())
            temp_loss.append(loss.item())
            if ind % log == 0:
                if wand:
                    wandb.log({'Train loss': np.mean(temp_loss)})
                print(f'Train loss: {np.mean(temp_loss)}')
                temp_loss = []
                if val_df is not None and ind != 0:
                    validate_model(T5, valloader, label_format, id_to_label, wand, val_verbose)
        print(f'The average loss for this epoch was {np.mean(train_loss)}')
    
    print('Finished Training!')
    return T5
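
The loss reporting inside the training loop is easy to misread, so here is a small sketch of just that bookkeeping, with the model removed. The function name `windowed_means` and the loss values are made up; the logic follows the loop above, where a report is issued whenever `ind % log == 0` and the window is then cleared.

```python
def windowed_means(losses, log=100):
    """Return (batch_index, mean_loss) pairs the way the loop above reports them."""
    reports, window = [], []
    for ind, loss in enumerate(losses):
        window.append(loss)
        if ind % log == 0:
            # Report the mean of everything seen since the last report
            reports.append((ind, sum(window) / len(window)))
            window = []
    return reports

# Made-up losses for one epoch of 205 batches
fake_losses = [14.0] + [0.5] * 100 + [0.1] * 104
reports = windowed_means(fake_losses)
for ind, mean_loss in reports:
    print(f'batch {ind}: Train loss: {mean_loss}')
```

Two consequences follow: the first reported value of each epoch is the loss of a single batch (index 0), and batches after the last multiple of `log` only show up in the epoch average, not in a windowed report.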

# Perplexity Scoring
To avoid training a separate model for every candidate input format, we measure how well the pretrained T5 already predicts the target under each format and look for the lowest perplexity (equivalently, the lowest loss). The format with the lowest score should then be the one that trains the best model.

In [None]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
T5 = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
T5.to(device)

In [17]:
desc = 'description: hand armor tranzonic acquisition corp. nitrile sterile powder free latex free exam gloves black 6 mil 10/100 - xx-large 12" length. '
add_on = """The shade of this item is <extra_id_0>. The fit is <extra_id_1>. The type1 is <extra_id_2>. Its powdered status is <extra_id_3>. The latex-free status of this item is <extra_id_4>. This item comes in <extra_id_5> per box. Each item has <extra_id_6> per case. This item is used for <extra_id_7>. This item is made out of <extra_id_8>. It is <extra_id_9> thick. This item is <extra_id_10> long. The brand family of this item is <extra_id_11>. The company who produces this item is <extra_id_12>."""
#target = '<extra_id_0> tranzonic acquisition corp. <extra_id_1>'
#good_add_on = "The color of this item is <extra_id_0>. The size is <extra_id_1>. This item's sterile status is <extra_id_2>. The powdered status is <extra_id_3>. Its latex-free status is <extra_id_4>. It comes in <extra_id_5> per box. There are <extra_id_6> per case. It is used for <extra_id_7>. The material of this item is <extra_id_8>. This item is <extra_id_9> thick. It is <extra_id_10> long. The product name is <extra_id_11>. <extra_id_12> manufactures this"
#bad_add_on = "Its color is <extra_id_0>. Its size is <extra_id_1>. The sterile status of this item is <extra_id_2>. The item's powdered status is <extra_id_3>. The item's latex-free status is <extra_id_4>. The pairs per box is <extra_id_5>. The boxes per case is <extra_id_6>. The use is <extra_id_7>. The material is <extra_id_8>. Thickness is <extra_id_9>. The length is <extra_id_10>. The branding for this item is <extra_id_11>. The company is <extra_id_12>"
target = 'targets: <extra_id_0> black <extra_id_1> extra-extra-large <extra_id_2> sterile <extra_id_3> powder-free <extra_id_4> latex-free <extra_id_5> 100 pairs <extra_id_6> 10 boxes <extra_id_7> exam <extra_id_8> nitrile <extra_id_9> 6mil <extra_id_10> 12in <extra_id_11> hand armor <extra_id_12> tranzonic acquisition corp. <extra_id_13>'

#desc = 'description: uline wht 3 mil nitrile glvs - xl uline , inc. white gloves size x-large The color is <extra_id_0>. The size is <extra_id_1>. The type1 is <extra_id_2>. The type2 is <extra_id_3>. The pairs per box is <extra_id_4>. The number of boxes per case is <extra_id_5>. The use is <extra_id_6>. The material is <extra_id_7>. The thickness is <extra_id_8>. The length is <extra_id_9>. The brand is <extra_id_10>, and the variation is <extra_id_11>. The company is <extra_id_12>'
#desc_target = 'targets: <extra_id_0> white <extra_id_1> extra-large <extra_id_2> none <extra_id_3> none <extra_id_4> none <extra_id_5> none <extra_id_6> none <extra_id_7> nitrile <extra_id_8> 3mil <extra_id_9> none <extra_id_10> none <extra_id_11> none <extra_id_12> uline inc. <extra_id_13>'

In [18]:
inp = desc + add_on
print(inp)
dct = tokenizer(inp, max_length=512, padding='max_length', return_tensors='pt')
desc_target = tokenizer(target, return_tensors='pt')['input_ids']
tmp = T5.forward(input_ids=dct['input_ids'].to(device), attention_mask=dct['attention_mask'].to(device), labels=desc_target.to(device))
desc_target.cpu(), dct['input_ids'].cpu(), dct['attention_mask'].cpu()
print('')

description: hand armor tranzonic acquisition corp. nitrile sterile powder free latex free exam gloves black 6 mil 10/100 - xx-large 12" length. The shade of this item is <extra_id_0>. The fit is <extra_id_1>. The type1 is <extra_id_2>. Its powdered status is <extra_id_3>. The latex-free status of this item is <extra_id_4>. This item comes in <extra_id_5> per box. Each item has <extra_id_6> per case. This item is used for <extra_id_7>. This item is made out of <extra_id_8>. It is <extra_id_9> thick. This item is <extra_id_10> long. The brand family of this item is <extra_id_11>. The company who produces this item is <extra_id_12>.



In [19]:
tmp.loss

tensor(2.8458, device='cuda:0', grad_fn=<NllLossBackward>)
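
The value above is the mean cross-entropy per target token (in nats), not the perplexity itself. Perplexity is its exponential, $\mathrm{PPL} = \exp(\text{loss})$, so ranking formats by loss or by perplexity gives the same order. A minimal sketch of the conversion, using the loss printed above:

```python
import math

mean_nll = 2.8458  # loss printed above: mean cross-entropy per target token
perplexity = math.exp(mean_nll)
print(f'perplexity: {perplexity:.1f}')
```

To compare candidate formats, repeat the scoring cell with each `add_on` string and keep the one with the lowest value.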

# Training
This section trains the T5 with the data we processed earlier

In [26]:
# Start a new run
wandb.init(project='medical-search-gloves', entity='spencervagg')

wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [27]:
# Base model
model = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)

# Log gradients and model parameters
wandb.watch(model)

[<wandb.wandb_torch.TorchGraph at 0x7ff9d04f7390>]

In [28]:
id_to_label = {'0': 'color', '1': 'size', '2': 'type1', '3': 'type2', '4': 'type3', '5': 'boxes', '6': 'case', 
                     '7': 'use', '8': 'material', '9': 'thickness', '10': 'length', '11': 'brand', '12': 'company'}

In [29]:
T5 = train_T5(train_df, model, val_df=val_df, batch_size=10, epochs=5, lr=1e-4, log=100, training_type=False, input_length=512,
              label_format='Original', id_to_label=id_to_label, val_verbose=False, wand=True)

Loading model and tokenizer
Creating datasets


0it [00:00, ?it/s]

Beginning Training
Training epoch 1


1it [00:01,  1.85s/it]

Train loss: 14.005777359008789


100it [02:30,  1.54s/it]

Train loss: 0.4819963195174932
Validating Model



35it [01:11,  2.06s/it]
101it [03:45, 23.59s/it]

Validation Loss: 0.021547212158995015
Accuracy of color is: 0.9507246376811594
Accuracy of size is: 0.9014492753623189
Accuracy of type1 is: 0.9768115942028985
Accuracy of type2 is: 0.9884057971014493
Accuracy of type3 is: 0.9681159420289855
Accuracy of boxes is: 0.9420289855072463
Accuracy of case is: 0.9855072463768116
Accuracy of use is: 0.9478260869565217
Accuracy of material is: 0.936231884057971
Accuracy of thickness is: 0.9855072463768116
Accuracy of length is: 0.9826086956521739
Accuracy of brand is: 0.8260869565217391
Accuracy of company is: 0.9304347826086956


200it [06:18,  1.56s/it]

Train loss: 0.01497873971471563
Validating Model



35it [01:12,  2.08s/it]
201it [07:34, 23.96s/it]

Validation Loss: 0.00847356187711869
Accuracy of color is: 0.9855072463768116
Accuracy of size is: 0.927536231884058
Accuracy of type1 is: 0.9855072463768116
Accuracy of type2 is: 0.991304347826087
Accuracy of type3 is: 0.9884057971014493
Accuracy of boxes is: 0.991304347826087
Accuracy of case is: 0.9971014492753624
Accuracy of use is: 0.9884057971014493
Accuracy of material is: 0.9565217391304348
Accuracy of thickness is: 0.9942028985507246
Accuracy of length is: 0.9971014492753624
Accuracy of brand is: 0.9565217391304348
Accuracy of company is: 0.9797101449275363


205it [07:40,  2.24s/it]
0it [00:00, ?it/s]

The average loss for this epoch was 0.31091660560313156
Training epoch 2


1it [00:01,  1.56s/it]

Train loss: 0.003323771758005023


100it [02:35,  1.55s/it]

Train loss: 0.0063193713303189725
Validating Model



35it [01:13,  2.09s/it]
101it [03:51, 23.99s/it]

Validation Loss: 0.0058968197820442065
Accuracy of color is: 0.9855072463768116
Accuracy of size is: 0.9304347826086956
Accuracy of type1 is: 0.9942028985507246
Accuracy of type2 is: 0.9942028985507246
Accuracy of type3 is: 0.9739130434782609
Accuracy of boxes is: 0.991304347826087
Accuracy of case is: 0.9971014492753624
Accuracy of use is: 0.991304347826087
Accuracy of material is: 0.9739130434782609
Accuracy of thickness is: 0.9971014492753624
Accuracy of length is: 1.0
Accuracy of brand is: 0.9739130434782609
Accuracy of company is: 0.9797101449275363


200it [06:25,  1.55s/it]

Train loss: 0.0040639800723874945
Validating Model



35it [01:13,  2.10s/it]
201it [07:42, 24.12s/it]

Validation Loss: 0.004189389292150736
Accuracy of color is: 0.9652173913043478
Accuracy of size is: 0.9681159420289855
Accuracy of type1 is: 0.9942028985507246
Accuracy of type2 is: 1.0
Accuracy of type3 is: 1.0
Accuracy of boxes is: 0.9884057971014493
Accuracy of case is: 0.9971014492753624
Accuracy of use is: 0.9971014492753624
Accuracy of material is: 0.9884057971014493
Accuracy of thickness is: 0.9971014492753624
Accuracy of length is: 1.0
Accuracy of brand is: 0.9942028985507246
Accuracy of company is: 0.9884057971014493


205it [07:47,  2.28s/it]
0it [00:00, ?it/s]

The average loss for this epoch was 0.005131898222706939
Training epoch 3


1it [00:01,  1.57s/it]

Train loss: 0.004152858164161444


100it [02:35,  1.55s/it]

Train loss: 0.0019060669612372295
Validating Model



35it [01:13,  2.09s/it]
101it [03:52, 24.05s/it]

Validation Loss: 0.0026364322834914284
Accuracy of color is: 0.9768115942028985
Accuracy of size is: 0.9710144927536232
Accuracy of type1 is: 0.9971014492753624
Accuracy of type2 is: 1.0
Accuracy of type3 is: 1.0
Accuracy of boxes is: 0.9942028985507246
Accuracy of case is: 1.0
Accuracy of use is: 1.0
Accuracy of material is: 0.9971014492753624
Accuracy of thickness is: 0.9942028985507246
Accuracy of length is: 0.9971014492753624
Accuracy of brand is: 0.991304347826087
Accuracy of company is: 0.9942028985507246


200it [06:25,  1.55s/it]

Train loss: 0.0014659534316160715
Validating Model



0it [00:00, ?it/s][A
1it [00:02,  2.19s/it][A
2it [00:04,  2.18s/it][A
3it [00:06,  2.16s/it][A
4it [00:08,  2.13s/it][A
5it [00:10,  2.11s/it][A
6it [00:12,  2.10s/it][A
7it [00:14,  2.09s/it][A
8it [00:16,  2.07s/it][A
9it [00:18,  2.10s/it][A
10it [00:20,  2.08s/it][A
11it [00:22,  2.07s/it][A
12it [00:25,  2.07s/it][A
13it [00:27,  2.08s/it][A
14it [00:29,  2.06s/it][A
15it [00:31,  2.08s/it][A
16it [00:33,  2.09s/it][A
17it [00:35,  2.11s/it][A
18it [00:37,  2.11s/it][A
19it [00:39,  2.11s/it][A
20it [00:41,  2.11s/it][A
21it [00:43,  2.09s/it][A
22it [00:45,  2.07s/it][A
23it [00:48,  2.07s/it][A
24it [00:50,  2.07s/it][A
25it [00:52,  2.05s/it][A
26it [00:54,  2.07s/it][A
27it [00:56,  2.11s/it][A
28it [00:58,  2.12s/it][A
29it [01:00,  2.10s/it][A
30it [01:02,  2.09s/it][A
31it [01:04,  2.12s/it][A
32it [01:07,  2.13s/it][A
33it [01:09,  2.13s/it][A
34it [01:11,  2.12s/it][A
35it [01:12,  2.08s/it]
201it [07:41, 23.87s/it]

Validation Loss: 0.002528658858084652
Accuracy of color is: 0.9768115942028985
Accuracy of size is: 0.9884057971014493
Accuracy of type1 is: 0.9971014492753624
Accuracy of type2 is: 1.0
Accuracy of type3 is: 1.0
Accuracy of boxes is: 0.9971014492753624
Accuracy of case is: 1.0
Accuracy of use is: 1.0
Accuracy of material is: 0.991304347826087
Accuracy of thickness is: 0.9971014492753624
Accuracy of length is: 1.0
Accuracy of brand is: 0.9884057971014493
Accuracy of company is: 0.9942028985507246


205it [07:47,  2.28s/it]
0it [00:00, ?it/s]

The average loss for this epoch was 0.0017228186660071425
Training epoch 4


1it [00:01,  1.55s/it]

Train loss: 0.0003967223165091127


100it [02:35,  1.55s/it]

Train loss: 0.000618046035815496
Validating Model



0it [00:00, ?it/s][A
1it [00:02,  2.03s/it][A
2it [00:04,  2.06s/it][A
3it [00:06,  2.11s/it][A
4it [00:08,  2.09s/it][A
5it [00:10,  2.08s/it][A
6it [00:12,  2.09s/it][A
7it [00:14,  2.11s/it][A
8it [00:16,  2.12s/it][A
9it [00:18,  2.10s/it][A
10it [00:21,  2.11s/it][A
11it [00:23,  2.12s/it][A
12it [00:25,  2.10s/it][A
13it [00:27,  2.08s/it][A
14it [00:29,  2.07s/it][A
15it [00:31,  2.06s/it][A
16it [00:33,  2.10s/it][A
17it [00:35,  2.08s/it][A
18it [00:37,  2.10s/it][A
19it [00:40,  2.14s/it][A
20it [00:42,  2.12s/it][A
21it [00:44,  2.11s/it][A
22it [00:46,  2.11s/it][A
23it [00:48,  2.13s/it][A
24it [00:50,  2.10s/it][A
25it [00:52,  2.09s/it][A
26it [00:54,  2.10s/it][A
27it [00:56,  2.08s/it][A
28it [00:58,  2.07s/it][A
29it [01:00,  2.08s/it][A
30it [01:03,  2.11s/it][A
31it [01:05,  2.10s/it][A
32it [01:07,  2.09s/it][A
33it [01:09,  2.07s/it][A
34it [01:11,  2.09s/it][A
35it [01:12,  2.08s/it]
101it [03:51, 23.85s/it]

Validation Loss: 0.0016884544766591198
Accuracy of color is: 0.9884057971014493
Accuracy of size is: 0.9942028985507246
Accuracy of type1 is: 0.9971014492753624
Accuracy of type2 is: 1.0
Accuracy of type3 is: 1.0
Accuracy of boxes is: 0.9971014492753624
Accuracy of case is: 1.0
Accuracy of use is: 1.0
Accuracy of material is: 0.9971014492753624
Accuracy of thickness is: 0.9971014492753624
Accuracy of length is: 1.0
Accuracy of brand is: 0.9884057971014493
Accuracy of company is: 0.9971014492753624


200it [06:24,  1.55s/it]

Train loss: 0.0007788899805018446
Validating Model



0it [00:00, ?it/s][A
1it [00:02,  2.19s/it][A
2it [00:04,  2.18s/it][A
3it [00:06,  2.17s/it][A
4it [00:08,  2.14s/it][A
5it [00:10,  2.09s/it][A
6it [00:12,  2.07s/it][A
7it [00:14,  2.08s/it][A
8it [00:16,  2.06s/it][A
9it [00:18,  2.06s/it][A
10it [00:20,  2.06s/it][A
11it [00:22,  2.06s/it][A
12it [00:24,  2.07s/it][A
13it [00:26,  2.07s/it][A
14it [00:29,  2.10s/it][A
15it [00:31,  2.07s/it][A
16it [00:33,  2.08s/it][A
17it [00:35,  2.11s/it][A
18it [00:37,  2.10s/it][A
19it [00:39,  2.14s/it][A
20it [00:41,  2.13s/it][A
21it [00:43,  2.11s/it][A
22it [00:46,  2.11s/it][A
23it [00:48,  2.08s/it][A
24it [00:50,  2.08s/it][A
25it [00:52,  2.10s/it][A
26it [00:54,  2.08s/it][A
27it [00:56,  2.10s/it][A
28it [00:58,  2.08s/it][A
29it [01:00,  2.07s/it][A
30it [01:02,  2.07s/it][A
31it [01:04,  2.09s/it][A
32it [01:06,  2.09s/it][A
33it [01:08,  2.08s/it][A
34it [01:10,  2.08s/it][A
35it [01:12,  2.07s/it]
201it [07:40, 23.80s/it]

Validation Loss: 0.0019346099330245384
Accuracy of color is: 0.9942028985507246
Accuracy of size is: 0.9739130434782609
Accuracy of type1 is: 0.9971014492753624
Accuracy of type2 is: 1.0
Accuracy of type3 is: 1.0
Accuracy of boxes is: 1.0
Accuracy of case is: 1.0
Accuracy of use is: 1.0
Accuracy of material is: 0.9971014492753624
Accuracy of thickness is: 0.9971014492753624
Accuracy of length is: 1.0
Accuracy of brand is: 0.9826086956521739
Accuracy of company is: 0.9971014492753624


205it [07:45,  2.27s/it]
0it [00:00, ?it/s]

The average loss for this epoch was 0.0006923567382306451
Training epoch 5


1it [00:01,  1.56s/it]

Train loss: 0.00010433672287035733


100it [02:35,  1.55s/it]

Train loss: 0.0006052157381054713
Validating Model



0it [00:00, ?it/s][A
1it [00:02,  2.10s/it][A
2it [00:04,  2.09s/it][A
3it [00:06,  2.10s/it][A
4it [00:08,  2.10s/it][A
5it [00:10,  2.08s/it][A
6it [00:12,  2.08s/it][A
7it [00:14,  2.08s/it][A
8it [00:16,  2.07s/it][A
9it [00:18,  2.09s/it][A
10it [00:20,  2.07s/it][A
11it [00:22,  2.07s/it][A
12it [00:25,  2.10s/it][A
13it [00:27,  2.08s/it][A
14it [00:29,  2.07s/it][A
15it [00:31,  2.11s/it][A
16it [00:33,  2.10s/it][A
17it [00:35,  2.11s/it][A
18it [00:37,  2.10s/it][A
19it [00:39,  2.11s/it][A
20it [00:41,  2.10s/it][A
21it [00:43,  2.11s/it][A
22it [00:46,  2.10s/it][A
23it [00:48,  2.08s/it][A
24it [00:50,  2.07s/it][A
25it [00:52,  2.10s/it][A
26it [00:54,  2.07s/it][A
27it [00:56,  2.07s/it][A
28it [00:58,  2.04s/it][A
29it [01:00,  2.06s/it][A
30it [01:02,  2.07s/it][A
31it [01:04,  2.10s/it][A
32it [01:06,  2.08s/it][A
33it [01:08,  2.09s/it][A
34it [01:11,  2.12s/it][A
35it [01:12,  2.07s/it]
101it [03:50, 23.75s/it]

Validation Loss: 0.0017158355476567522
Accuracy of color is: 0.9942028985507246
Accuracy of size is: 1.0
Accuracy of type1 is: 0.9971014492753624
Accuracy of type2 is: 1.0
Accuracy of type3 is: 1.0
Accuracy of boxes is: 0.9971014492753624
Accuracy of case is: 1.0
Accuracy of use is: 1.0
Accuracy of material is: 0.9942028985507246
Accuracy of thickness is: 0.9971014492753624
Accuracy of length is: 1.0
Accuracy of brand is: 0.991304347826087
Accuracy of company is: 0.991304347826087


200it [06:24,  1.55s/it]

Train loss: 0.0005944426456699148
Validating Model



0it [00:00, ?it/s][A
1it [00:02,  2.06s/it][A
2it [00:04,  2.05s/it][A
3it [00:06,  2.05s/it][A
4it [00:08,  2.07s/it][A
5it [00:10,  2.07s/it][A
6it [00:12,  2.06s/it][A
7it [00:14,  2.09s/it][A
8it [00:16,  2.07s/it][A
9it [00:18,  2.07s/it][A
10it [00:20,  2.10s/it][A
11it [00:22,  2.09s/it][A
12it [00:24,  2.06s/it][A
13it [00:26,  2.05s/it][A
14it [00:28,  2.06s/it][A
15it [00:31,  2.10s/it][A
16it [00:33,  2.09s/it][A
17it [00:35,  2.11s/it][A
18it [00:37,  2.14s/it][A
19it [00:39,  2.11s/it][A
20it [00:41,  2.09s/it][A
21it [00:43,  2.08s/it][A
22it [00:45,  2.09s/it][A
23it [00:47,  2.08s/it][A
24it [00:50,  2.12s/it][A
25it [00:52,  2.12s/it][A
26it [00:54,  2.12s/it][A
27it [00:56,  2.13s/it][A
28it [00:58,  2.14s/it][A
29it [01:00,  2.13s/it][A
30it [01:02,  2.11s/it][A
31it [01:04,  2.09s/it][A
32it [01:06,  2.08s/it][A
33it [01:09,  2.09s/it][A
34it [01:11,  2.09s/it][A
35it [01:12,  2.07s/it]
201it [07:40, 23.81s/it]

Validation Loss: 0.0024058061236116503
Accuracy of color is: 0.9942028985507246
Accuracy of size is: 0.991304347826087
Accuracy of type1 is: 0.9942028985507246
Accuracy of type2 is: 1.0
Accuracy of type3 is: 0.9942028985507246
Accuracy of boxes is: 0.9971014492753624
Accuracy of case is: 1.0
Accuracy of use is: 1.0
Accuracy of material is: 0.9971014492753624
Accuracy of thickness is: 0.9971014492753624
Accuracy of length is: 1.0
Accuracy of brand is: 0.9942028985507246
Accuracy of company is: 0.9884057971014493


205it [07:45,  2.27s/it]

The average loss for this epoch was 0.0005885098946192757
Finished Training!





In [30]:
torch.save(T5.cpu().state_dict(), 'folderOnColab/Models/T5_Single_Original_Format_V1.bin')

# Model Evaluation
With the trained model, we can evaluate how well it performs on the validation set (and on the test set if needed).

## Single Task

In [None]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
val_data = T5MedSearchDataset(val_df, 'Validation')
# The single-task loops below decode one example at a time, so the batch size must be 1
valloader = torch.utils.data.DataLoader(val_data, batch_size=1, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
test_data = T5MedSearchDataset(test_df, 'Test')
testloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
T5 = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
#T5.load_state_dict(torch.load(GOOGLE_DRIVE_PATH + 'T5_V5.bin'))
T5.load_state_dict(torch.load('folderOnColab/Models/T5_V5_Case_Uncase_1e3.bin'))
T5.to(device)
T5.eval()  # disable dropout so evaluation is deterministic

### Validation

In [None]:
validation_output = []
for ind, val in enumerate(valloader):
    input_ids = val["input_ids"].to(device).squeeze(1)
    attention_mask = val["attention_mask"].to(device).squeeze(1)
    labels = val['labels']
    targets = val['targets']
    generated_ids = T5.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100).squeeze()
    predicted_span = tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    desc = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids.squeeze(0), skip_special_tokens=True))
    validation_output.append({'description': desc, 'target_output': targets, 'predicted_output': predicted_span})
    print('Example ', ind)
    print('Description: ', desc)
    print('Target Text', targets)
    print('T5 Prediction', predicted_span, '\n')

In [None]:
val_output_df = pd.DataFrame(validation_output)
val_output_df.to_pickle('folderOnColab/Results/T5_V5_Val_Results_Case_Uncase.pkl')

In [None]:
results = []
for i, v in val_output_df.iterrows():
    pred_dict = answer_to_dict(v.predicted_output)
    targ_dict = answer_to_dict(v.target_output[0])
    results.append(comparing_labels(targ_dict, pred_dict))
results_val_df = pd.DataFrame(results)

In [None]:
results_val_df.mean(axis=0)

In [None]:
# Print out any mislabels
label = 'size'
for i, v in val_output_df[results_val_df[label].values == 0].iterrows():
    print(f'Index {i}')
    print(v.description)
    print('Target', answer_to_dict(v.target_output[0])[label])
    print('Pred', answer_to_dict(v.predicted_output)[label], '\n')

### Test

In [None]:
test_output = []
T5.to(device)
for ind, val in enumerate(testloader):
    input_ids = val["input_ids"].to(device).squeeze(1)
    attention_mask = val["attention_mask"].to(device).squeeze(1)
    labels = val['labels']
    targets = val['targets']
    generated_ids = T5.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100).squeeze()
    predicted_span = tokenizer.decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    desc = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids.squeeze(0), skip_special_tokens=True))
    test_output.append({'description': desc, 'target_output': targets, 'predicted_output': predicted_span})
    print('Example ', ind)
    print('Description: ', desc)
    print('Target Text', targets)
    print('T5 Prediction', predicted_span, '\n')

In [None]:
test_output_df = pd.DataFrame(test_output)
test_output_df.to_pickle('folderOnColab/Results/T5_V5_Test_Results.pkl')

In [None]:
results = []
for i, v in test_output_df.iterrows():
    pred_dict = answer_to_dict(v.predicted_output)
    targ_dict = answer_to_dict(v.target_output[0])
    results.append(comparing_labels(targ_dict, pred_dict))
results_test_df = pd.DataFrame(results)

In [None]:
results_test_df.mean(axis=0)

In [None]:
# Print out any mislabels
label = 'company'
for i, v in test_output_df[results_test_df[label].values == 0].iterrows():
    print(f'Index {i}')
    print(v.description)
    print('Target', answer_to_dict(v.target_output[0])[label])
    try:
        print('Pred', answer_to_dict(v.predicted_output)[label], '\n')
    except Exception:
        # Fall back to the raw prediction when it cannot be parsed or lacks this label
        print(v.predicted_output, '\n')

## Multiple Task

### Label Format Evaluation
Evaluates the model with the label format, i.e.:
``` 
targets: color: none | size: medium | ...
```
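
As an illustration of the label format shown above, here is a minimal sketch of how such a string could be turned into a dictionary. The function name `parse_label_format` is hypothetical; it is not the notebook's `answer_to_dict_v2`, whose implementation is not shown in this section and may behave differently.

```python
def parse_label_format(text):
    """Parse 'targets: color: none | size: medium' into a dict (hypothetical helper)."""
    prefix = 'targets:'
    text = text.strip()
    # Drop the leading 'targets:' prefix if it is present.
    if text.startswith(prefix):
        text = text[len(prefix):]
    parsed = {}
    for field in text.split('|'):
        # Split on the first colon only, so values that contain ':' survive.
        name, sep, value = field.partition(':')
        if sep:
            parsed[name.strip()] = value.strip()
    return parsed


example = parse_label_format('targets: color: none | size: medium')
print(example)
```

Fields without a colon (for example a trailing empty segment after the last `|`) are skipped rather than raising an error.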

In [None]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
val_data = T5MedSearchDataset(val_df, 'Validation', task='Multiple')
valloader = torch.utils.data.DataLoader(val_data, batch_size=10, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#T5 = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
#T5.load_state_dict(torch.load('folderOnColab/Models/T5_Multiple_V1.bin'))
T5.to(device)

In [None]:
validation_output = []
for ind, val in enumerate(valloader):
    # Validation Input
    input_ids = val["input_ids"].to(device).squeeze(1)
    attention_mask = val["attention_mask"].to(device).squeeze(1)
    labels = val['labels']
    targets = val['targets']

    # Generate model output
    # No .squeeze() here: batch_decode expects a 2D batch, even if the last batch holds one example
    generated_ids = T5.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100)
    predicted_span = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    desc = [tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[i], skip_special_tokens=True)) for i in range(input_ids.shape[0])]

    # Add outputs to list of dfs
    tmp_df = pd.DataFrame([desc, targets, predicted_span, val['type']]).transpose()
    tmp_df.columns = ['input', 'target_output', 'predicted_output', 'task']
    validation_output.append(tmp_df)

    # Print outputs/targets
    print('Example ', ind)
    print('Description: ', desc)
    print('Target Text', targets)
    print('T5 Prediction', predicted_span, '\n')

In [None]:
val_output_df = pd.concat(validation_output)
val_output_df.to_pickle('folderOnColab/Results/T5_Multiple_V1_Val_Results.pkl')

In [None]:
val_output_df = pd.read_pickle('/home/Resources/datasets/medical-search/Gloves/T5_Multiple_V1_Val_Results.pkl') # Results Local Runtime
val_output_df.head()

In [None]:
results = []
task = 'Query'
for i, v in val_output_df[val_output_df.task == task].iterrows():
    pred_dict = answer_to_dict_v2(v.predicted_output)
    targ_dict = answer_to_dict_v2(v.target_output)
    results.append(comparing_labels_V2(targ_dict, pred_dict))
results_val_df = pd.DataFrame(results)

In [None]:
# This cell gets the accuracy for the tasks
print(f'Accuracy for {task} task with none labels')
for i in results_val_df.columns:
    print(f'Accuracy of {i} is: {np.nanmean(results_val_df[i])}')

Accuracy for Query task with none labels
Accuracy of color is: 0.9797101449275363
Accuracy of size is: 0.9884057971014493
Accuracy of type1 is: 0.9956521739130435
Accuracy of type2 is: 1.0
Accuracy of type3 is: 0.9942028985507246
Accuracy of boxes is: 0.991304347826087
Accuracy of cases is: 0.9869565217391304
Accuracy of use is: 0.991304347826087
Accuracy of material is: 0.991304347826087
Accuracy of thickness is: 0.981159420289855
Accuracy of length is: 1.0
Accuracy of brand is: 0.9768115942028985
Accuracy of primary_brand is: 0.8318840579710145
Accuracy of secondary_brand is: 0.9536231884057971
Accuracy of company is: 0.9884057971014493
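
The accuracy cell above averages each column with `np.nanmean`, which skips missing entries. A small sketch on toy data, assuming (this is not confirmed by the section) that a comparison result is stored as `NaN` when a label does not apply to an example:

```python
import numpy as np
import pandas as pd

# Toy per-example results: 1.0 = correct, 0.0 = wrong, NaN = label not applicable (assumed convention).
rows = [
    {'color': 1.0, 'size': 1.0},
    {'color': 0.0, 'size': np.nan},
    {'color': 1.0, 'size': 1.0},
]
toy_results = pd.DataFrame(rows)

# nanmean divides by the number of non-NaN entries, so 'size' is averaged over 2 rows, not 3.
accuracy = {col: float(np.nanmean(toy_results[col])) for col in toy_results.columns}
print(accuracy)
```

Under that assumption, each label's accuracy is computed only over the examples where the label is present, so the denominators can differ between labels.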


In [None]:
# Print out any mislabels
# Run this without task being set to Query/Description so all mislabels are caught 
incorrect_labels = []
for col in results_val_df.columns:
    for i, v in val_output_df[results_val_df[col].values == 0].iterrows():
        print(f'Task {v.task}, Label {col}')
        print(v.input)
        targ = answer_to_dict_v2(v.target_output)[col]
        pred = answer_to_dict_v2(v.predicted_output)[col]
        print('Target', targ)
        print('Pred', pred, '\n')
        incorrect_labels.append({'task': v.task, 'incorrect_label': col, 'input': v.input, 'target_output': targ, 'predicted_output': pred})
incorrect_labels_df = pd.DataFrame(incorrect_labels)

In [None]:
incorrect_labels_df.to_pickle('/home/Resources/datasets/medical-search/Gloves/T5_Multiple_V1_Val_Incorrect.pkl')

### Original Format Evaluation
Evaluates the model with the original (sentinel-token) format, i.e. \<extra_id_0\> none \<extra_id_1\> medium ...
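
To make the original format concrete, the sketch below splits a sentinel-token string into a label dictionary. `parse_sentinel_format` is a hypothetical helper written for illustration and the two-entry mapping is a toy example; the notebook itself relies on `answer_to_dict_v2`, which is not shown here.

```python
import re


def parse_sentinel_format(text, id_to_label):
    """Map each <extra_id_N> sentinel to the text that follows it (hypothetical helper)."""
    # Capture the sentinel number and everything up to the next sentinel or the end of the string.
    pattern = re.compile(r'<extra_id_(\d+)>(.*?)(?=<extra_id_\d+>|$)', re.DOTALL)
    parsed = {}
    for sentinel_id, value in pattern.findall(text):
        label = id_to_label.get(sentinel_id)
        # Sentinels without a known label name are ignored.
        if label is not None:
            parsed[label] = value.strip()
    return parsed


toy_id_to_label = {'0': 'color', '1': 'size'}
parsed_example = parse_sentinel_format('targets: <extra_id_0> none <extra_id_1> medium', toy_id_to_label)
print(parsed_example)
```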

In [31]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
val_data = T5MedSearchDataset(val_df, 'Validation', input_length=512)
valloader = torch.utils.data.DataLoader(val_data, batch_size=10, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#T5 = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
#T5.load_state_dict(torch.load('folderOnColab/Models/T5_Multiple_Original_Format_V1.bin'))
T5.to(device)

In [33]:
validation_output = []
for ind, val in enumerate(valloader):
    input_ids = val["input_ids"].to(device).squeeze(1)
    attention_mask = val["attention_mask"].to(device).squeeze(1)
    labels = val['labels']
    targets = val['targets']


    # No .squeeze() here: batch_decode expects a 2D batch, even if the last batch holds one example
    generated_ids = T5.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100)
    predicted_span = tokenizer.batch_decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)
    desc = [tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[i], skip_special_tokens=False)) for i in range(input_ids.shape[0])]
    desc = cleanup_T5_tokenization(desc, input=True)
    predicted_span = cleanup_T5_tokenization(predicted_span, input=False)
    # Strip a trailing sentinel token from each target (raw string avoids invalid-escape warnings)
    targets = [re.sub(r'<extra_id_[0-9]+>$', '', t).strip() for t in targets]
    
    
    tmp_df = pd.DataFrame([desc, targets, predicted_span]).transpose()
    tmp_df.columns = ['input', 'target_output', 'predicted_output']
    validation_output.append(tmp_df)
    
    
    print('Example ', ind)
    print('Description: ', desc)
    print('Target Text', targets)
    print('T5 Prediction', predicted_span, '\n')


Example  0
Description:  ["description: nitrile powder free exam gloves black 6 mil 10/100 - ex-large hand armor tranzonic acquisition corp .. The color of this item is <extra_id_0> . The size is <extra_id_1> . This item's sterile status is <extra_id_2> . The powdered status is <extra_id_3> . Its latex-free status is <extra_id_4> . It comes in <extra_id_5> per box. There are <extra_id_6> per case. It is used for <extra_id_7> . The material of this item is <extra_id_8> . This item is <extra_id_9> thick. It is <extra_id_10> long. The product name is <extra_id_11> . <extra_id_12> manufactures this.", "description: nitrile powder free exam gloves black 6 mil 10/100 - medium hand armor tranzonic acquisition corp .. The color of this item is <extra_id_0> . The size is <extra_id_1> . This item's sterile status is <extra_id_2> . The powdered status is <extra_id_3> . Its latex-free status is <extra_id_4> . It comes in <extra_id_5> per box. There are <extra_id_6> per case. It is used for <extra_

In [34]:
val_output_df = pd.concat(validation_output)
val_output_df.to_pickle('folderOnColab/Results/T5_Single_Original_Format_V1_Val_Results.pkl')

In [35]:
val_output_df = pd.read_pickle('folderOnColab/Results/T5_Single_Original_Format_V1_Val_Results.pkl') # Results Local Runtime
val_output_df.head()

| | input | target_output | predicted_output |
|---|---|---|---|
| 0 | description: nitrile powder free exam gloves b... | `targets: <extra_id_0> black <extra_id_1> extra...` | `targets: <extra_id_0> black <extra_id_1> extra...` |
| 1 | description: nitrile powder free exam gloves b... | `targets: <extra_id_0> black <extra_id_1> mediu...` | `targets: <extra_id_0> black <extra_id_1> mediu...` |
| 2 | description: ambitex glv exam non-sterile stre... | `targets: <extra_id_0> cream <extra_id_1> extra...` | `targets: <extra_id_0> cream <extra_id_1> extra...` |
| 3 | description: nitrile exam blue powder free glo... | `targets: <extra_id_0> blue <extra_id_1> large ...` | `targets: <extra_id_0> blue <extra_id_1> large ...` |
| 4 | description: ambitex glv exam non-sterile nitr... | `targets: <extra_id_0> black <extra_id_1> extra...` | `targets: <extra_id_0> black <extra_id_1> extra...` |


In [None]:
# Dicts for extra_id number to label name
query_id_to_label = {'0': 'color', '1': 'size', '2': 'type1', '3': 'type2', '4': 'type3', '5': 'boxes', '6': 'case', 
                     '7': 'use', '8': 'material', '9': 'thickness', '10': 'length', '11': 'brand'}
desc_id_to_label = {'0': 'color', '1': 'size', '2': 'type1', '3': 'type2', '4': 'boxes', '5': 'case', 
                     '6': 'use', '7': 'material', '8': 'thickness', '9': 'length', '10': 'primary_brand',
                    '11': 'secondary_brand', '12': 'company'}

In [38]:
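# Map each extra_id number to its label name (same 13-label mapping used in the Medequip query cells)
id_to_label = {'0': 'color', '1': 'size', '2': 'type1', '3': 'type2', '4': 'type3', '5': 'boxes', '6': 'case',
               '7': 'use', '8': 'material', '9': 'thickness', '10': 'length', '11': 'brand', '12': 'company'}

results = []
task = 'Query'
# [val_output_df.task == task]
for i, v in val_output_df.iterrows():
    pred_dict = answer_to_dict_v2(v.predicted_output,  'Original', id_to_label)
    targ_dict = answer_to_dict_v2(v.target_output,  'Original', id_to_label)
    results.append(comparing_labels_V2(targ_dict, pred_dict, n_a=True))
results_val_df = pd.DataFrame(results)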

In [39]:
# This cell gets the accuracy for the tasks
print(f'Accuracy for {task} task without none labels Original Label Setup')
for i in results_val_df.columns:
    print(f'Accuracy of {i} is: {np.nanmean(results_val_df[i])}')

Accuracy for Query task without none labels Original Label Setup
Accuracy of color is: 0.9867549668874173
Accuracy of size is: 0.9901960784313726
Accuracy of type1 is: 0.9555555555555556
Accuracy of type2 is: 1.0
Accuracy of type3 is: 0.9733333333333334
Accuracy of boxes is: 0.9857142857142858
Accuracy of case is: 1.0
Accuracy of use is: 1.0
Accuracy of material is: 0.9958677685950413
Accuracy of thickness is: 0.9861111111111112
Accuracy of length is: 1.0
Accuracy of brand is: 0.9884169884169884
Accuracy of company is: 0.9884057971014493


# Medequip Query Evaluation

In [6]:
df = pd.read_pickle('/content/folderOnColab/Data/MedEquip_glove_search_history-auto-NL.pkl')
df.head()

| | Search Term | Total Unique Searches | Gloves |
|---|---|---|---|
| 33 | gloves | 44 | True |
| 135 | nitrile gloves | 26 | True |
| 1277 | sterile gloves | 7 | True |
| 1637 | exam gloves | 6 | True |
| 2310 | latex gloves | 5 | True |


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
T5 = T5ForConditionalGeneration.from_pretrained('t5-base', return_dict=True)
T5.load_state_dict(torch.load('folderOnColab/Models/T5_Single_Original_Format_V1.bin'))
T5.to(device)

In [12]:
id_to_label = {'0': 'color', '1': 'size', '2': 'type1', '3': 'type2', '4': 'type3', '5': 'boxes', '6': 'case', 
                     '7': 'use', '8': 'material', '9': 'thickness', '10': 'length', '11': 'brand', '12': 'company'}

input_prompts = '''The color of this item is <extra_id_0>. The size is <extra_id_1>. This item's sterile status is <extra_id_2>. ''' \
            + '''The powdered status is <extra_id_3>. Its latex-free status is <extra_id_4>. It comes in <extra_id_5> per box. ''' \
            + '''There are <extra_id_6> per case. It is used for <extra_id_7>. The material of this item is <extra_id_8>. This item ''' \
            + '''is <extra_id_9> thick. It is <extra_id_10> long. The product name is <extra_id_11>. <extra_id_12> manufactures this.'''

In [18]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
results = []
for ind, val in df.iterrows():
    # Lowercase and tokenize the search term, pad hyphens that are followed by a letter
    # (but not preceded by one), then append the fill-in-the-blank prompts
    search_term = str(val['Search Term']).lower()
    model_input = 'description: ' + \
            re.sub(r'(?<![a-zA-Z])-(?=[a-zA-Z])',' - ',' '.join(nltk.word_tokenize(search_term))) \
             + '. ' + input_prompts

    dct = tokenizer(model_input, max_length=512, padding='max_length', return_tensors='pt')
    input_ids = dct["input_ids"].to(device).squeeze(1)
    attention_mask = dct["attention_mask"].to(device).squeeze(1)

    # Keep the sentinel tokens in the decoded text so the labels can be recovered afterwards
    generated_ids = T5.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100)
    predicted_span = tokenizer.batch_decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)

    inference_tokens = cleanup_T5_tokenization(predicted_span, input=False)[0]
    inference_tokens = answer_to_dict_v2(inference_tokens, 'Original', id_to_label, na = False)

    tmp = {'search_term': val['Search Term'], 'input': model_input, 'labels': inference_tokens,
           'total_unique_searches': val['Total Unique Searches'], 'gloves': val['Gloves']}
    results.append(tmp)
    print('Search Term: ', val['Search Term'])
    print('Model Input: ', model_input)
    print(predicted_span, '\n')
results_df = pd.DataFrame(results)
results_df.head()

Search Term:  gloves
Model Input:  description: gloves. The color of this item is <extra_id_0>. The size is <extra_id_1>. This item's sterile status is <extra_id_2>. The powdered status is <extra_id_3>. Its latex-free status is <extra_id_4>. It comes in <extra_id_5> per box. There are <extra_id_6> per case. It is used for <extra_id_7>. The material of this item is <extra_id_8>. This item is <extra_id_9> thick. It is <extra_id_10> long. The product name is <extra_id_11>. <extra_id_12> manufactures this.
['targets: <extra_id_0> none <extra_id_1> none <extra_id_2> none <extra_id_3> none <extra_id_4> none <extra_id_5> none <extra_id_6> none <extra_id_7> none <extra_id_8> none <extra_id_9> none <extra_id_10> none <extra_id_11> none <extra_id_12> none'] 

Search Term:  nitrile gloves
Model Input:  description: nitrile gloves. The color of this item is <extra_id_0>. The size is <extra_id_1>. This item's sterile status is <extra_id_2>. The powdered status is <extra_id_3>. Its latex-free status

| | search_term | input | labels | total_unique_searches | gloves |
|---|---|---|---|---|---|
| 0 | gloves | description: gloves. The color of this item is... | `{}` | 44 | True |
| 1 | nitrile gloves | description: nitrile gloves. The color of this... | `{'material': ['nitrile'], 'company': ['target']}` | 26 | True |
| 2 | sterile gloves | description: sterile gloves. The color of this... | `{'type1': ['sterile'], 'company': ['target']}` | 7 | True |
| 3 | exam gloves | description: exam gloves. The color of this it... | `{'use': ['exam'], 'company': ['exam gloves']}` | 6 | True |
| 4 | latex gloves | description: latex gloves. The color of this i... | `{'type3': ['latex']}` | 5 | True |


In [19]:
results_df.to_pickle('/content/folderOnColab/Results/medequip_search_results_single_original_format_v1.pkl')

# Check Out Medequip Results
This section runs the model on the Medequip item descriptions for the rows that appear in both the Medequip and GUDID databases.

In [None]:
df = pd.read_csv('folderOnColab/Data/full_gloves_gudid_with_brand_corrected.tsv', sep='\t')
brand_df = pd.read_pickle('folderOnColab/Data/clean_brand_diff_gloves_gudid_V3.pkl').drop_duplicates(subset='_brand_name_diff')
label_df = pd.read_pickle('folderOnColab/Data/gudid_gloves_labelled_V3.pkl').reset_index()

In [None]:
label_df = label_df[~pd.isnull(label_df.ProductID)]
label_df['input_text'] = [re.sub(r'(?<![a-zA-Z])-(?=[a-zA-Z])',' - ',' '.join(nltk.word_tokenize(str(x)))) for x in label_df.ItemDescription]
label_df.input_text = 'description: ' + label_df.input_text

In [None]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')
results = []
for ind, val in label_df.iterrows():
    model_input = str(val.input_text)
    dct = tokenizer(model_input, max_length=256, padding='max_length', return_tensors='pt')
    input_ids = dct["input_ids"].to(device).squeeze(1)
    attention_mask = dct["attention_mask"].to(device).squeeze(1)
    generated_ids = T5.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=100)
    # batch_decode returns a list, so 'labels' holds a one-element list per row
    predicted_span = tokenizer.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    tmp = {'primary_di': val['primary_di'], 'input': model_input, 'labels': predicted_span,
           'ProductID': val['ProductID'], 'device_description': val['device_description'],
           'brand_name': val['brand_name']}
    results.append(tmp)
    print('Description: ', val['ItemDescription'])
    print('Model Input: ', model_input)
    print(predicted_span, '\n')
results_df = pd.DataFrame(results)
results_df.head()

In [None]:
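# label_df was filtered without resetting its index, so assign by position rather than by index label
results_df['ItemDescription'] = label_df.ItemDescription.values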
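
A note on the column assignment in the cell above: when a pandas Series is assigned to a DataFrame column, values are matched by index label, not by row position. The toy sketch below (all names are illustrative, not from the notebook) shows how a filtered frame with gaps in its index behaves under both kinds of assignment.

```python
import pandas as pd

source = pd.DataFrame({'text': ['a', 'b', 'c', 'd']})
filtered = source[source.text != 'b']       # index labels are now 0, 2, 3
target = pd.DataFrame({'n': [10, 20, 30]})  # index labels are 0, 1, 2

# Assigning the Series aligns on index labels: label 1 has no match, label 3 is dropped.
aligned = target.copy()
aligned['text'] = filtered.text

# Assigning the underlying array is positional: row i receives the i-th filtered value.
positional = target.copy()
positional['text'] = filtered.text.values

print(aligned)
print(positional)
```

Positional assignment is only appropriate when both frames were built in the same row order, as is the case when one is produced by iterating over the other.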

In [None]:
results_df.to_pickle('/content/folderOnColab/Results/medequip_description_gudid_intersection_cased_results.pkl')

## Evaluate

In [None]:
label_df = label_df[['primary_di', 'labels', 'brand_labels', 'ProductID', 'ItemDescription', 'catalog_number']]
df.primary_di = df.primary_di.apply(lambda x: x.zfill(14) if x[0].isdigit() else x)

# Join dataframes
df = df.merge(label_df, on='primary_di', how='left')
df.head()
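
The zero-padding and left join in the cell above can be illustrated on toy data. This is only a sketch: the column values below are invented, and it assumes, as the original lambda does, that every `primary_di` is a non-empty string.

```python
import pandas as pd

# Invented example rows; only the column names mirror the notebook.
left = pd.DataFrame({'primary_di': ['123', 'ABC1'], 'company_name': ['acme', 'globex']})
right = pd.DataFrame({'primary_di': ['0' * 11 + '123'], 'ProductID': ['P-1']})

# Pad identifiers that start with a digit to 14 characters; leave the others untouched.
left['primary_di'] = left['primary_di'].apply(lambda x: x.zfill(14) if x[0].isdigit() else x)

# A left join keeps every row of `left`; rows without a match get NaN in the joined columns.
merged = left.merge(right, on='primary_di', how='left')
print(merged)
```

Without the padding, `'123'` would not match its 14-character counterpart and the joined columns for that row would be empty.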

In [None]:
# Replace missing brand values with explicit placeholders
df['secondary_brand'] = df['secondary_brand'].fillna('None')
df['primary_brand'] = df['primary_brand'].fillna('None')
df['brand_name'] = df['brand_name'].fillna('')

In [None]:
temp_df = pd.read_pickle('/content/folderOnColab/Results/medequip_description_gudid_intersection_cased_results.pkl')
temp_df.head()

In [None]:
df = pd.merge(temp_df, df[['primary_di', 'company_name', 'primary_brand', 'secondary_brand']], on='primary_di')
df = df.rename(columns={'labels': 'predicted'})

In [None]:
df = pd.merge(df, label_df[['primary_di', 'catalog_number', 'labels', 'brand_labels']], on='primary_di')

In [None]:
results = []
for i, v in df.iterrows():
    targ_text = output_label(v.labels,v.brand_labels, v.company_name, (v.primary_brand, v.secondary_brand))
    pred_dict = answer_to_dict(v.predicted[0])
    targ_dict = answer_to_dict(targ_text)
    results.append(comparing_labels(targ_dict, pred_dict))
results_test_df = pd.DataFrame(results)

In [None]:
results_test_df.mean(axis=0)

color              0.547445
size               0.879562
type1              0.739051
type2              0.563869
use                0.784672
material           0.691606
thickness          0.952555
length             0.934307
primary_brand      0.206204
secondary_brand    0.928832
company            0.000000
dtype: float64

In [None]:
# Print out any mislabels
label = 'color'
for i, v in df[results_test_df[label].values == 0].iterrows():
    print(f'Index {i}')
    print(v.input)
    print('Target', v.labels[label])
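    try:
        # predicted holds the one-element list returned by batch_decode
        print('Pred', answer_to_dict(v.predicted[0])[label], '\n')
    except Exception:
        # Fall back to the raw prediction when it cannot be parsed or lacks this label
        print(v.predicted, '\n')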

In [None]:
df.to_pickle('/content/folderOnColab/Results/medequip_description_gudid_intersection_cased_results_updated.pkl')