In [None]:
#@title Press ᐅ to run me!
# Setup for this notebook cell: require a GPU runtime, quietly install pinned
# dependencies, then define train_model() and the widget UI further down.

import torch
# Without a CUDA device: show a GIF explaining how to switch the runtime to a
# GPU, then terminate the kernel process so nothing else in the cell runs.
if not torch.cuda.is_available():
    from IPython.display import Image
    display(Image(url='https://galileo-public-tutorial-data.s3.us-west-1.amazonaws.com/use_gpus_colab.gif', width=700))
    print("GPU not enabled! Please go to Runtime > Change runtime type > Hardware Accelerator > GPU")
    
    # Restart the runtime
    import os, time
    time.sleep(1) # gives the print statements time to flush
    os._exit(0) # exits without allowing the next cell to run

print("Installing `dataquality` and other libraries. This should take ~30 seconds.")


# IPython shell escapes (`!pip ...`); all installer output is discarded.
!pip install ipywidgets==7.7.1 traitlets==5.1.1 &> /dev/null
# Install dataquality without needing a restart of the kernel
# (--no-deps: dependencies are not resolved by this line; the next line installs
# a pinned package set, presumably covering what is skipped here -- verify when
# bumping any version.)
!pip install vaex vaex-core==4.9.1 vaex-hdf5==0.12.2 wrapt==1.13.3 datasets transformers torchmetrics evaluate responses==0.18 dataquality==0.6.1 --no-deps &> /dev/null
# NOTE(review): `transformers` is unpinned on the line above and pinned to
# 4.22.2 here; presumably this later pinned install is the version that sticks.
!pip install diskcache==5.2.1 gorilla==0.3.0 resource==0.2.1 types-requests==2.25.2 transformers==4.22.2 multiprocess xxhash blake3==0.2.1 aplus frozendict!=2.2.0 nest-asyncio==1.4.0 rich jedi cachetools==5.2.0 pydantic==1.9.2 &> /dev/null

from IPython.display import clear_output
clear_output()  # hide the install message printed above

"""
============================================================
BEGIN MODEL TRAINING CODE [REMOVE ME IN THE FUTURE, and potentially replace pytorch code with HF for a lower amount of code overall needed]
============================================================
"""
import numpy as np
import random
import torch.nn.functional as F
import torchmetrics
from tqdm.notebook import tqdm
import transformers
from transformers import AutoTokenizer
from typing import List
import datasets
from datasets import load_dataset
import time

# Name of the selected example dataset or uploaded CSV. Assigned by
# example_button_click() / on_upload() and turned into the Galileo run name
# inside train_model().
file_name = None

def train_model():
    """Fine-tune a DistilBERT classifier on the module-level ``df`` and log it to Galileo.

    Reads two module-level globals that the widget callbacks further down assign:

    * ``df`` -- a DataFrame with ``text``, ``label`` and ``id`` columns. If it also
      has a ``split`` column, rows tagged ``"train"`` are used for training and rows
      tagged ``"val"`` for validation; otherwise a random 80/20 split is drawn.
    * ``file_name`` -- used (with ``/`` -> ``-`` and ``.`` -> ``_``) as the run name.

    Trains for at most ``MAX_NUM_EPOCHS`` epochs and stops early the first time the
    summed validation loss does not improve on the previous epoch. The per-step
    training loss is streamed to a live Plotly widget, and the datasets, embeddings
    and logits are logged to Galileo through ``dq``.
    """
    transformers.logging.disable_progress_bar()
    transformers.logging.set_verbosity_error()
    from google.colab import output
    # Presumably needed so Colab renders the third-party Plotly FigureWidget below.
    output.enable_custom_widget_manager()

    # --- Train / validation split ---------------------------------------------
    if "split" in df:
        train_df = df[df["split"] == "train"]
        test_df = df[df["split"] == "val"]
    else:
        train_df = df.sample(frac=0.8)
        test_df = df.drop(train_df.index)

    # 🔭🌕 Initializing a new run in Galileo. Each run is part of a project.
    dq.init(
        task_type="text_classification",
        project_name="galileo_in_5_minutes",  # TODO: fix this
        run_name=file_name.replace("/", "-").replace(".", "_"),
    )

    # Label vocabulary: training labels in order of first appearance, followed by
    # any label that only occurs in the validation split. Building it from the
    # training split alone made the label lookup fail whenever the (random) split
    # left a rare label out of training.
    list_of_labels = train_df["label"].unique().tolist()
    known_labels = set(list_of_labels)
    list_of_labels += [
        label for label in test_df["label"].unique().tolist() if label not in known_labels
    ]

    # Load the tokenizer once and share it between both datasets.
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

    class TextDataset(torch.utils.data.Dataset):
        """Tokenized DataFrame that logs itself to Galileo on construction.

        Items are ``(id, input_ids, attention_mask, label_index)`` tuples.
        """

        def __init__(
            self, dataset: pd.DataFrame, split: str, list_of_labels: List[str] = None
        ):
            self.dataset = dataset

            # 🔭🌕 Logging the dataset with Galileo
            # Note: this works seamlessly because self.dataset has text, label, and
            # id columns. See `help(dq.log_dataset)` for more info
            dq.log_dataset(self.dataset, split=split)

            self.encodings = tokenizer(
                self.dataset["text"].tolist(), truncation=True, padding=True
            )

            self.list_of_labels = list_of_labels or self.dataset["label"].unique().tolist()
            # Dict lookup instead of list.index(): O(1) per row instead of O(#labels).
            label_to_index = {label: i for i, label in enumerate(self.list_of_labels)}
            self.labels = np.array(
                [label_to_index[label] for label in self.dataset["label"]]
            )

        def __getitem__(self, idx):
            x = torch.tensor(self.encodings["input_ids"][idx])
            attention_mask = torch.tensor(self.encodings["attention_mask"][idx])
            y = self.labels[idx]
            sample_idx = self.dataset.id.iloc[idx]
            return sample_idx, x, attention_mask, y

        def __len__(self):
            return len(self.dataset)

    train_dataset = TextDataset(train_df, split="training", list_of_labels=list_of_labels)
    test_dataset = TextDataset(test_df, split="validation", list_of_labels=list_of_labels)

    # 🔭🌕 Registering the list of labels for the run
    dq.set_labels_for_run(train_dataset.list_of_labels)

    from torch.nn import Linear
    from transformers import AutoModel

    class TextClassificationModel(torch.nn.Module):
        """Defines a Pytorch text classification bert based model."""

        def __init__(self, num_labels: int):
            super().__init__()
            self.feature_extractor = AutoModel.from_pretrained("distilbert-base-uncased")
            self.classifier = Linear(self.feature_extractor.config.hidden_size, num_labels)

        def forward(self, x, attention_mask, ids):
            """Return class logits; also logs embeddings and logits to Galileo for ``ids``."""
            encoded_layers = self.feature_extractor(
                input_ids=x, attention_mask=attention_mask
            ).last_hidden_state
            # The hidden state at position 0 (the first token) is the sequence embedding.
            classification_embedding = encoded_layers[:, 0]
            logits = self.classifier(classification_embedding)

            # 🔭🌕 Logging the model logits and embeddings
            dq.log_model_outputs(
                embs=classification_embedding, logits=logits, ids=ids
            )

            return logits

    BATCH_SIZE = 32
    MAX_NUM_EPOCHS = 100

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    val_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False
    )

    print("Initializing the DistilBERT model.")

    model = TextClassificationModel(num_labels=len(train_dataset.list_of_labels))
    model.to(device)

    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5
    )

    # Live training-loss plot. The layout never changes, so it is set once here
    # instead of being re-applied on every training step.
    fig = go.FigureWidget(data={"y": [0], "x": [0]})
    display(fig)
    fig.update_layout(
        title={"text": "Train Loss"},
        xaxis={"title": "Step"},
        yaxis={"title": "Loss"},
    )

    train_losses = []
    val_losses = []  # summed validation loss of each completed epoch
    stopped_early = False

    for epoch in range(MAX_NUM_EPOCHS):
        dq.set_epoch(epoch)  # 🔭🌕 Setting the epoch

        model.train()
        dq.set_split("training")  # 🔭🌕 Setting split to training

        for x_idxs, x, attention_mask, y in train_dataloader:
            x = x.to(device)
            attention_mask = attention_mask.to(device)
            # The DataLoader already collates labels into a tensor; as_tensor moves
            # it to the device without the copy-construct warning torch.tensor() emits.
            y = torch.as_tensor(y, device=device)

            # zero the parameter gradients, then forward + backward + optimize
            optimizer.zero_grad()
            logits = model(x, attention_mask, x_idxs)
            loss = F.cross_entropy(logits, y)
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            fig.update(data=[{"y": train_losses, "x": list(range(len(train_losses)))}])

        model.eval()
        with torch.no_grad():
            dq.set_split("validation")  # 🔭🌕 Setting split to validation

            val_loss = 0.0
            for x_idxs, x, attention_mask, y in val_dataloader:
                x = x.to(device)
                attention_mask = attention_mask.to(device)
                y = torch.as_tensor(y, device=device)

                logits = model(x, attention_mask, x_idxs)
                val_loss += F.cross_entropy(logits, y).item()

        # Early stopping: quit as soon as the validation loss stops improving.
        if val_losses and val_loss >= val_losses[-1]:
            stopped_early = True
            break
        val_losses.append(val_loss)

    if stopped_early:
        print(f"Finished Training. Early Stopped at epoch {epoch}.")
    else:
        print(f"Finished Training. Reached the maximum of {MAX_NUM_EPOCHS} epochs.")

    dq.finish()  # 🔭🌕 Complete the Galileo workflow with a call to dq.finish()

"""
============================================================
END MODEL TRAINING CODE
============================================================
"""

import ipywidgets.widgets as widgets
from IPython.display import clear_output
import io
import pandas as pd
import dataquality as dq
from google.colab import output
import time
import plotly.graph_objects as go
import numpy as np
from scipy.optimize import curve_fit

from google.colab import data_table


GALILEO_HTML_EXAMPLE = """<center> <img
src=https://galileo-public-tutorial-data.s3.us-west-1.amazonaws.com/galileo_quickstart.svg alt='GalileoRing' width=600> <br> Choose an example dataset to get started.</center>"""
GALILEO_HTML_FILE = """<center>or:<br> Upload a .csv file with 2 columns: "text" and "label"</center>"""
GALILEO_HTML_TOKEN = """<center> <img
src=https://galileo-public-tutorial-data.s3.us-west-1.amazonaws.com/Logo.svg alt='GalileoLogo' width=200> <br> Now, copy the token from your <a
href="https://console.cloud.rungalileo.io/get-token" target="_blank">Galileo Console account</a> and paste it below:</center>"""
GALILEO_HTML_END = """<center>
<b>Tip: </b>If you don't have an account with Galileo, you can create one <a href="https://console.cloud.rungalileo.io/sign-up?utm=galileo-in-5-min" target="_blank">here</a></center>"""

box_layout = widgets.Layout(
    display="flex", flex_flow="column", align_items="center", width="85%"
)


newsgroups = widgets.Button(description="newsgroups")
trec6 = widgets.Button(description="trec6")
conv_intent = widgets.Button(description="conv_intent")
token_widget = widgets.Password(description="")
start_button = widgets.Button(description="Start!", icon="fa-rocket")
login_widget = widgets.VBox(
    [
        widgets.HTML(GALILEO_HTML_TOKEN),
        token_widget,
        start_button,
        widgets.HTML(GALILEO_HTML_END),
    ],
    layout=box_layout
)

upload_widget = widgets.FileUpload(accept='.csv', multiple=True)
starter = widgets.VBox(
    [
        widgets.HTML(GALILEO_HTML_EXAMPLE),
        widgets.HBox([trec6, conv_intent, newsgroups]),
        widgets.HTML(GALILEO_HTML_FILE), 
        upload_widget
     ],
        layout=box_layout
)

def restart(t):
    """Return to the dataset-picker screen (callback of the "Again!" button).

    Args:
        t: the clicked button; unused, but required by ``Button.on_click``.
    """
    clear_output()
    # Re-enable every dataset source. The upload widget used to be *disabled*
    # here, which made uploading a CSV impossible on a second run.
    for source in (newsgroups, trec6, conv_intent, upload_widget):
        source.disabled = False
    display(starter)

def submit_event(t):
    """Validate the pasted Galileo token, log in and run training ("Start!" callback).

    The token is cleared from the password box right away. If it is empty or
    rejected, the login screen is shown again. After training an "Again!" button
    is always displayed (even if training raised) so the user can start over.

    Args:
        t: the clicked button; unused, but required by ``Button.on_click``.
    """
    token = token_widget.value
    # Erase token
    token_widget.value = ""
    clear_output()
    dq.config.token = token

    # One client serves both the validity check and the user lookup; skip the
    # network round-trip entirely for an empty token.
    api_client = dq.clients.api.ApiClient() if token else None
    if api_client is None or not api_client.valid_current_user():
        print(
            "It seems as though the token you provided is not valid. Please try again"
        )
        display(login_widget)
        return

    dq.config.current_user = api_client.get_current_user()["email"]
    dq.login()

    import warnings
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            train_model()
    finally:
        # Offer a way back to the start screen even when training fails;
        # any exception is still re-raised after the button is shown.
        print("Start Over?")
        start_over_button = widgets.Button(description="Again!", icon="fa-rocket")
        start_over_button.on_click(restart)
        display(start_over_button)


# "Start!" validates the pasted token and launches training.
start_button.on_click(submit_event)

# Show the dataset picker, the first screen of the flow.
display(starter)

# Dataset consumed by train_model(); assigned by example_button_click() or
# on_upload() before the login widget is displayed.
df = None

def example_button_click(button):
    """Load the ``rungalileo/<name>`` example dataset chosen by *button*.

    Disables every dataset source, then loads the Hugging Face dataset. It builds
    the module-level ``df`` from the ``train`` and ``validation`` splits: integer
    labels are mapped to class names, and rows are tagged ``"train"`` / ``"val"``
    in a ``split`` column. Finally it previews ``df`` and shows the login widget.
    It also sets ``file_name``, which becomes the Galileo run name.

    Args:
        button: the clicked example button; its ``description`` is the dataset name.

    Raises:
        Whatever ``load_dataset`` raises. The sources are re-enabled first, so the
        picker is not left locked.
    """
    global file_name, df

    sources = (newsgroups, trec6, conv_intent, upload_widget)
    for source in sources:
        source.disabled = True
    datasets.logging.set_verbosity_error()
    datasets.logging.disable_progress_bar()

    dataset_name = button.description
    file_name = f"rungalileo_dataset_{dataset_name}"
    try:
        ds = load_dataset(f"rungalileo/{dataset_name}")
    except Exception:
        # A failed download must not leave every dataset source disabled forever.
        for source in sources:
            source.disabled = False
        raise

    label_names = ds["train"].features["label"].names

    def labelled_split(split, tag):
        # Map integer class ids to their names and tag every row with its split.
        frame = ds[split].to_pandas()
        frame["label"] = frame["label"].apply(lambda class_id: label_names[class_id])
        frame["split"] = tag
        return frame

    # ignore_index gives the combined frame unique row labels; both splits
    # otherwise start their index at 0.
    df = pd.concat(
        [labelled_split("train", "train"), labelled_split("validation", "val")],
        ignore_index=True,
    )
    clear_output()
    display(data_table.DataTable(df, include_index=False, num_rows_per_page=10))

    display(login_widget)

def on_upload(inputs):
    """Load an uploaded CSV into the module-level ``df`` (FileUpload ``value`` observer).

    Only the last uploaded file is used, and the CSV must contain ``text`` and
    ``label`` columns. Rows missing either value are dropped and ``text`` is cast
    to ``str``. A fresh ``0..n-1`` ``id`` column is then assigned, overwriting any
    existing one. On success the globals ``file_name`` and ``df`` are updated, the
    data is previewed and the login widget is shown. On any problem the picker is
    shown again with an error message and the globals are left untouched.

    Args:
        inputs: the traitlets change dict. ``inputs['new']`` maps file names to
            upload records whose ``"content"`` holds the raw bytes.
    """
    global file_name, df

    uploads = inputs['new']
    if not uploads:
        # Nothing uploaded (e.g. the widget value was cleared).
        return
    new_file_name = list(uploads.keys())[-1]
    uploaded_file = list(uploads.values())[-1]["content"]

    def show_error(message):
        clear_output()
        display(starter)
        print(message)

    try:
        uploaded_df = pd.read_csv(io.BytesIO(uploaded_file))
    except (pd.errors.ParserError, pd.errors.EmptyDataError, UnicodeDecodeError):
        show_error("Could not read the uploaded file as a CSV. Please fix and try uploading again!")
        return

    if not ("label" in uploaded_df.columns and "text" in uploaded_df.columns):
        show_error("CSV must have 'label' and 'text' columns. Please fix and try uploading again!")
        return

    # Rows without text or label cannot be tokenized or mapped to a class, and
    # the tokenizer expects string input (numeric-looking text is parsed as numbers).
    cleaned = uploaded_df.dropna(subset=["text", "label"]).reset_index(drop=True)
    cleaned["text"] = cleaned["text"].astype(str)
    if cleaned.empty:
        show_error("CSV has no rows with both 'text' and 'label'. Please fix and try uploading again!")
        return
    # Contiguous ids for the `id` column that train_model's dq.log_dataset relies on.
    cleaned["id"] = list(range(len(cleaned)))
    dropped = len(uploaded_df) - len(cleaned)

    file_name = new_file_name
    df = cleaned

    clear_output()
    display(file_name)
    if dropped:
        print(f"Dropped {dropped} row(s) missing 'text' or 'label'.")
    display(data_table.DataTable(df, include_index=False, num_rows_per_page=10))
    display(login_widget)

# A new file selection triggers on_upload; every example button loads the
# Hugging Face dataset named by its description.
upload_widget.observe(on_upload, names='value')
for _example in (newsgroups, trec6, conv_intent):
    _example.on_click(example_button_click)
del _example


 ​