# In [None]:
# -*- coding: utf-8 -*-
"""jigsaw_2tower_multi_rule_context_contrastive.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1nJNMkKf_9JMDQxDOgIHhrCpfw6eVjFe1
"""

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# import os
# for dirname, _, filenames in os.walk('/kaggle/input'):
# for dirname, _, filenames in os.walk('.'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# !pip install pytorch_metric_learning


# ---- Imports
import random, math
import numpy as np
import pandas as pd
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold
from sklearn.metrics import roc_auc_score

from transformers import AutoTokenizer, AutoModel
from transformers import get_linear_schedule_with_warmup


# -----------------------------
# Load and preprocess data
# -----------------------------
import os

# Encoder weights: use the offline copy attached on Kaggle when it is present,
# otherwise fall back to the Hugging Face Hub model name.
_KAGGLE_MODEL_DIR = "/kaggle/input/xlm-roberta-base-offline/xlm_roberta_base_offline"
MODEL_PATH = _KAGGLE_MODEL_DIR if os.path.isdir(_KAGGLE_MODEL_DIR) else "xlm-roberta-base"

# Competition data: pick the first directory that actually contains train.csv
# (Kaggle, then Colab/Drive, then the local download), so the same cell runs in
# every environment without editing commented-out paths. The last entry is the
# fallback, so a missing dataset still fails loudly in read_csv below.
_DATA_DIRS = (
    "/kaggle/input/jigsaw-agile-community-rules",
    "/content/drive/MyDrive/Colab Notebooks",
    "C:/Users/satra/Downloads/jigsaw-agile-community-rules",
)
DATA_DIR = next(
    (d for d in _DATA_DIRS if os.path.isfile(f"{d}/train.csv")),
    _DATA_DIRS[-1],
)
trn = f"{DATA_DIR}/train.csv"
tst = f"{DATA_DIR}/test.csv"

df_trn = pd.read_csv(trn)
df_tst = pd.read_csv(tst)


def fill_empty_examples_pandas(df):
    """Normalise the four example columns of ``df`` in place and return it.

    Missing values become '' and every value is cast to ``str``. Within each
    (example_1, example_2) pair, an empty slot is then back-filled from its
    sibling, so a row with only one positive (or negative) example gets it in
    both slots. Rows where both slots are empty stay ''.
    """
    example_pairs = (
        ("positive_example_1", "positive_example_2"),
        ("negative_example_1", "negative_example_2"),
    )

    # Pass 1: NaN -> '' and force string dtype on all four columns.
    for first, second in example_pairs:
        for name in (first, second):
            df[name] = df[name].fillna('').astype(str)

    # Pass 2: copy across the pair. The second slot sees the already-filled
    # first slot, so both end up equal when only one was present.
    for first, second in example_pairs:
        df[first] = df[first].where(df[first] != '', df[second])
        df[second] = df[second].where(df[second] != '', df[first])

    return df


def get_text(value):
    """Return ``value`` as a string, or '' when pandas treats it as missing."""
    if pd.isna(value):
        return ''
    return str(value)


def extract_texts(row):
    """Collect one row's text fields into a dict, mapping missing values to ''.

    Output keys are short aliases (``pos1``, ``neg2``, ...) of the source
    column names; each value goes through ``get_text``.
    """
    alias_to_column = {
        "body": "body",
        "rule": "rule",
        "subreddit": "subreddit",
        "pos1": "positive_example_1",
        "pos2": "positive_example_2",
        "neg1": "negative_example_1",
        "neg2": "negative_example_2",
    }
    return {alias: get_text(row[column]) for alias, column in alias_to_column.items()}

# Normalise the example columns of both splits (fill_empty_examples_pandas
# mutates the frame it is given and returns it).
df_trn = fill_empty_examples_pandas(df_trn)
df_tst = fill_empty_examples_pandas(df_tst)

# Per-row dict of the raw text fields.
# NOTE(review): the "inputs" column is not read anywhere else in this notebook;
# RuleTripleDataset / build_rule_context read the original columns directly.
df_trn["inputs"] = df_trn.apply(extract_texts, axis=1)
df_tst["inputs"] = df_tst.apply(extract_texts, axis=1) # Apply to test data too

# NOTE(review): N_EPOCHS and k_folds are not referenced elsewhere in this
# notebook; training uses CFG.epochs and kfold_train(n_splits=...) instead.
N_EPOCHS = 8
k_folds = 5

def build_rule_context(row: pd.Series):
    """Build rule-conditioned context strings for one row.

    Returns a tuple ``(pos_ctx, neg_ctx, rule_only)``:
      - ``pos_ctx``: "<RULE> {rule} <POS> {example}" for each non-empty positive example
      - ``neg_ctx``: "<RULE> {rule} <NEG> {example}" for each non-empty negative example
      - ``rule_only``: "<RULE> {rule}" (currently unused by the dataset, kept for scoring experiments)
    Columns absent from ``row`` are treated as ''.
    """
    rule_prefix = f"<RULE> {str(row.get('rule', ''))}"

    positive_cols = ("positive_example_1", "positive_example_2")
    negative_cols = ("negative_example_1", "negative_example_2")
    positives = [row.get(col, "") for col in positive_cols]
    negatives = [row.get(col, "") for col in negative_cols]

    pos_ctx = [f"{rule_prefix} <POS> {text}" for text in positives if text]
    neg_ctx = [f"{rule_prefix} <NEG> {text}" for text in negatives if text]
    return pos_ctx, neg_ctx, rule_prefix


class RuleTripleDataset(Dataset):
    """Tokenised anchor / rule-context examples for the dual encoder.

    Each item is a dict with:
      - anchor_input_ids / anchor_attention_mask: [L]. "<BODY> body" and
        "<RULE> rule" are encoded as a tokenizer text pair with
        ``truncation="longest_first"``, so an over-long body is trimmed
        instead of the rule being cut off the end of a single concatenated
        string.
      - pos_input_ids / pos_attention_mask: [P, L]. Up to P_MAX=2 positive
        contexts "<RULE> r <POS> pe_i", padded with empty strings.
      - neg_input_ids / neg_attention_mask: [N, L]. Same for negatives
        ("<RULE> r <NEG> ne_i").
      - pos_mask / neg_mask: [P] / [N] float32, 1.0 where the slot holds a
        real context and 0.0 for padding slots.
      - label: float32 scalar from ``rule_violation`` (omitted when is_test).

    Args:
        df: frame with ``body``, ``rule`` and the example columns
            (plus ``rule_violation`` unless ``is_test``).
        tokenizer: tokenizer callable (e.g. from AutoTokenizer).
        max_len: fixed sequence length L (padding="max_length").
        is_test: when True no label is emitted.
        seed: accepted for API compatibility; not used.
    """

    def __init__(self, df, tokenizer, max_len=256, is_test=False, seed=42):
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.is_test = is_test
        # Kept for backward compatibility; the anchor is now tokenised as a
        # text pair, so this separator string is no longer inserted manually.
        self.sep = self.tokenizer.sep_token or " </s> "
        self.P_MAX, self.N_MAX = 2, 2  # two positives, two negatives

    def __len__(self):
        return len(self.df)

    def _tok(self, texts):
        """Tokenise a string or list of strings to fixed-length (max_len) tensors."""
        return self.tokenizer(
            texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )

    def _tok_pair(self, first, second):
        """Tokenise a text pair; 'longest_first' trims the longer segment so both survive."""
        return self.tokenizer(
            first,
            second,
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_len,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        body = str(row["body"])
        rule = str(row["rule"])

        # 1) Anchor: BODY and RULE as a pair, so the rule is never truncated away.
        anc = self._tok_pair(f"<BODY> {body}", f"<RULE> {rule}")
        anchor_input_ids = anc["input_ids"].squeeze(0).to(torch.long)
        anchor_attention_mask = anc["attention_mask"].squeeze(0).to(torch.long)

        # 2) Separate positive/negative contexts (0..2 each)
        pos_ctx, neg_ctx, _ = build_rule_context(row)

        # Pad with '' (or truncate) to a fixed slot count so batches stack.
        pos_ctx = (pos_ctx + [""] * self.P_MAX)[:self.P_MAX]
        neg_ctx = (neg_ctx + [""] * self.N_MAX)[:self.N_MAX]

        tok_pos = self._tok(pos_ctx)   # [P, L]
        tok_neg = self._tok(neg_ctx)   # [N, L]

        # Slot masks: 1 for a real context, 0 for a padding slot.
        pos_mask = torch.tensor([1 if s else 0 for s in pos_ctx], dtype=torch.float32)
        neg_mask = torch.tensor([1 if s else 0 for s in neg_ctx], dtype=torch.float32)

        item = {
            "anchor_input_ids": anchor_input_ids,
            "anchor_attention_mask": anchor_attention_mask,
            "pos_input_ids": tok_pos["input_ids"].to(torch.long),           # [P, L]
            "pos_attention_mask": tok_pos["attention_mask"].to(torch.long), # [P, L]
            "neg_input_ids": tok_neg["input_ids"].to(torch.long),           # [N, L]
            "neg_attention_mask": tok_neg["attention_mask"].to(torch.long), # [N, L]
            "pos_mask": pos_mask,
            "neg_mask": neg_mask,
        }
        if not self.is_test:
            item["label"] = torch.tensor(float(row["rule_violation"]), dtype=torch.float32)
        return item


def safe_normalize(x, eps=1e-6):
    """L2-normalise ``x`` along its last dimension.

    The per-row norm is floored at ``eps`` so an all-zero row maps to zeros
    instead of NaN.
    """
    length = x.norm(dim=-1, keepdim=True)
    return x / length.clamp(min=eps)


class DualEncoder(nn.Module):
    """Shared-encoder model with a classification head and a contrastive projection.

    One pretrained encoder (``AutoModel.from_pretrained(model_name)``) embeds
    the anchor, the positive rule contexts and the negative rule contexts; the
    first-token hidden state (after dropout) represents each sequence.

    Args:
        model_name: name or path passed to ``AutoModel.from_pretrained``.
        hidden_size: encoder hidden size; read from the encoder config
            (falling back to 768) when None.
        proj_dim: output size of both projection layers.
        dropout: dropout probability on the first-token hidden states.
    """

    def __init__(self, model_name, hidden_size=None, proj_dim=256, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        if hidden_size is None:
            hidden_size = getattr(self.encoder.config, "hidden_size", 768)
        self.dropout = nn.Dropout(dropout)
        # Shared projection for anchor / pos / neg embeddings (contrastive space).
        self.proj_pair = nn.Linear(hidden_size, proj_dim)
        # Classification branch: projection followed by a scalar head.
        self.proj_cls = nn.Linear(hidden_size, proj_dim)
        self.cls_head = nn.Linear(proj_dim, 1)

    def _encode_hidden(self, input_ids, attention_mask):
        """Return the dropout-regularised first-token hidden state, shape [batch, H]."""
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]
        return self.dropout(cls)

    def _pool_contexts(self, input_ids, attention_mask, slot_mask):
        """Encode [B, S, L] context slots and masked-mean-pool them to [B, H].

        Slots with ``slot_mask == 0`` are zeroed before summing; the divisor is
        floored at 1 so a row with no filled slot pools to zeros, not NaN.
        """
        B, S, L = input_ids.shape
        flat_ids = input_ids.reshape(B * S, L).contiguous()
        flat_msk = attention_mask.reshape(B * S, L).contiguous()
        h = self._encode_hidden(flat_ids, flat_msk).view(B, S, -1)   # [B, S, H]
        h = h * slot_mask.unsqueeze(-1)                               # zero padding slots
        denom = slot_mask.sum(dim=1, keepdim=True).clamp_min(1.0)     # [B, 1]
        return h.sum(dim=1) / denom                                   # [B, H]

    def forward(
        self,
        anchor_input_ids, anchor_attention_mask,
        pos_input_ids, pos_attention_mask, pos_mask,
        neg_input_ids, neg_attention_mask, neg_mask,
    ):
        """Return ``(logits [B], anchor_emb [B, D], pos_emb [B, D], neg_emb [B, D])``.

        Embeddings are L2-normalised via ``safe_normalize``; logits are raw
        (no sigmoid) and come from the anchor only.

        Raises:
            ValueError: if the pos/neg context tensors disagree on batch size
                or sequence length.
        """
        if (neg_input_ids.shape[0] != pos_input_ids.shape[0]
                or neg_input_ids.shape[-1] != pos_input_ids.shape[-1]):
            raise ValueError(
                "pos/neg context tensors must share batch size and sequence length"
            )

        # Encoder calls stay in anchor -> pos -> neg order.
        h_anchor = self._encode_hidden(anchor_input_ids, anchor_attention_mask)   # [B, H]
        h_pos_pool = self._pool_contexts(pos_input_ids, pos_attention_mask, pos_mask)
        h_neg_pool = self._pool_contexts(neg_input_ids, neg_attention_mask, neg_mask)

        # Project each representation once into the shared contrastive space.
        anchor_emb = safe_normalize(self.proj_pair(h_anchor))
        pos_emb = safe_normalize(self.proj_pair(h_pos_pool))
        neg_emb = safe_normalize(self.proj_pair(h_neg_pool))

        # Classification branch (on the anchor)
        logits = self.cls_head(self.proj_cls(h_anchor)).squeeze(-1)             # [B]

        return logits, anchor_emb, pos_emb, neg_emb

# ================================
# Hybrid Jigsaw ACR
# ================================

# ------------------
# CONFIG (edit here)
# ------------------
@dataclass
class Config:
    """Hyper-parameters and runtime settings for the hybrid model.

    A single instance is bound to the module-level name ``CFG`` below; the
    rest of the notebook reads settings as ``CFG.<field>``.
    """
    # Defaults to MODEL_PATH so the tokenizer and encoder load from the same
    # location chosen in the data-loading cell (e.g. an offline copy).
    model_name: str = MODEL_PATH
    max_len: int = 256                   # tokens per sequence (anchor and each context)
    num_negatives: int = 4               # K (must be constant); NOTE(review): not read anywhere in this notebook
    batch_size: int = 16
    epochs: int = 3                      # number of epochs run by train_hybrid
    lr: float = 2e-5
    wd: float = 0.01
    warmup_ratio: float = 0.1            # fraction of total steps used for linear warm-up
    dropout: float = 0.1
    proj_dim: int = 256
    # Mixed precision is only enabled when CUDA is available; on CPU it only
    # produced warnings.
    use_amp: bool = torch.cuda.is_available()
    grad_clip: float = 1.0
    pair_margin: float = 0.2
    loss_alpha_cls: float = 0.6          # blend: alpha * BCE + (1-alpha) * pair
    seed: int = 42
    num_workers: int = 2
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


CFG = Config()

# ----------------
# Reproducibility
# ----------------
def set_seed(s=42):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) with ``s``.

    Also puts cuDNN in deterministic mode and disables its autotuner so that
    repeated runs pick the same kernels.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(s)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed once at import time; kfold_train re-seeds per seed and per fold.
set_seed(CFG.seed)

# -------------
# Tokenizer
# -------------
# Module-level tokenizer; make_loader passes it to every RuleTripleDataset.
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name)

# -------------
# Losses
# -------------
# Takes raw logits: run_one_epoch feeds model logits here directly and applies
# torch.sigmoid separately only for the reported probabilities.
bce_loss_fn = nn.BCEWithLogitsLoss()

def pairwise_margin_loss(anchor_emb, pos_emb, neg_emb, labels, margin=0.2):
    """Label-aware hinge loss on anchor/context similarities.

    With ``gap = <a, p> - <a, n>`` (cosine when inputs are L2-normalised), a
    violating example (label 1) should satisfy ``gap >= margin`` and a
    non-violating one (label 0) ``-gap >= margin``. Returns the batch mean of
    ``max(0, margin - sign * gap)`` where ``sign`` is +1 / -1 for label 1 / 0.

    labels: [B] float or long in {0,1}
    """
    gap = (anchor_emb * pos_emb).sum(dim=-1) - (anchor_emb * neg_emb).sum(dim=-1)  # [B]
    direction = 2.0 * labels.float() - 1.0                                        # {0,1} -> {-1,+1}
    hinge = torch.clamp(margin - direction * gap, min=0.0)
    return hinge.mean()

# --------------------
# Training / Evaluation
# --------------------
def make_loader(df, is_test=False, shuffle=None):
    """Wrap ``df`` in a RuleTripleDataset and return a DataLoader over it.

    Args:
        df: frame to wrap.
        is_test: when True the dataset emits no labels and worker processes
            are not kept alive between iterations.
        shuffle: shuffle order; defaults to ``not is_test`` (the previous
            behaviour). Pass ``shuffle=False`` for labelled data whose
            predictions must line up with ``df`` rows (e.g. validation/OOF).
    """
    if shuffle is None:
        shuffle = not is_test
    ds = RuleTripleDataset(
        df, tokenizer,
        max_len=CFG.max_len,
        is_test=is_test,
        seed=CFG.seed,
    )
    return DataLoader(
        ds,
        batch_size=CFG.batch_size,
        shuffle=shuffle,
        num_workers=CFG.num_workers,
        # Pinned host memory only helps host->GPU copies; skip it on CPU.
        pin_memory=str(CFG.device).startswith("cuda"),
        persistent_workers=(CFG.num_workers > 0) and (not is_test),
        drop_last=False,
    )

def run_one_epoch(model, loader, optimizer=None, scheduler=None, scaler=None):
    """Run one pass over ``loader``: train if ``optimizer`` is given, else evaluate.

    Loss per batch is ``CFG.loss_alpha_cls * BCE + (1 - CFG.loss_alpha_cls) *
    pairwise_margin_loss``; batches without a "label" key contribute zero loss.
    In training mode the optimizer (and ``scheduler``, if given) step once per
    batch, with optional gradient clipping at ``CFG.grad_clip``. Evaluation
    runs with autograd disabled.

    Returns:
        ``(avg_loss, auc, (probs, labels, logits))``. ``auc`` is None when the
        loader yielded no labels; the arrays follow loader order.
    """
    is_train = optimizer is not None
    model.train(is_train)

    # Autocast only for training on CUDA; the scaled backward path only when a
    # GradScaler was actually supplied.
    amp_enabled = CFG.use_amp and is_train and str(CFG.device).startswith("cuda")
    use_scaler = CFG.use_amp and scaler is not None

    total_loss, nsteps = 0.0, 0
    all_probs, all_labels, all_logits = [], [], []

    # Without this, evaluation would record activations for a backward pass
    # that never happens, wasting memory.
    with torch.set_grad_enabled(is_train):
        for batch in loader:
            batch = {k: (v.to(CFG.device) if torch.is_tensor(v) else v) for k, v in batch.items()}

            with torch.autocast(device_type="cuda", enabled=amp_enabled):
                logits, a_emb, p_emb, n_emb = model(
                    anchor_input_ids=batch["anchor_input_ids"],
                    anchor_attention_mask=batch["anchor_attention_mask"],
                    pos_input_ids=batch["pos_input_ids"],
                    pos_attention_mask=batch["pos_attention_mask"],
                    pos_mask=batch["pos_mask"],
                    neg_input_ids=batch["neg_input_ids"],
                    neg_attention_mask=batch["neg_attention_mask"],
                    neg_mask=batch["neg_mask"],
                )

                if "label" in batch:
                    loss_cls = bce_loss_fn(logits, batch["label"])
                    # label-aware margin
                    loss_pair = pairwise_margin_loss(
                        a_emb, p_emb, n_emb, batch["label"], margin=CFG.pair_margin
                    )
                else:
                    loss_cls = torch.zeros((), device=CFG.device)
                    loss_pair = torch.zeros((), device=CFG.device)

                loss = CFG.loss_alpha_cls * loss_cls + (1 - CFG.loss_alpha_cls) * loss_pair

            if is_train:
                optimizer.zero_grad(set_to_none=True)
                if use_scaler:
                    scaler.scale(loss).backward()
                    # Unscale first so clipping applies to the true gradient norm.
                    scaler.unscale_(optimizer)
                    if CFG.grad_clip is not None:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.grad_clip)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    loss.backward()
                    if CFG.grad_clip is not None:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.grad_clip)
                    optimizer.step()

                if scheduler is not None:
                    scheduler.step()

            total_loss += loss.item()
            nsteps += 1

            if "label" in batch:
                # Cast to fp32 before sigmoid so AMP half-precision logits do
                # not lose precision in the reported probabilities.
                batch_logits = logits.detach().float().cpu()
                all_logits.extend(batch_logits.tolist())
                all_probs.extend(torch.sigmoid(batch_logits).tolist())
                all_labels.extend(batch["label"].detach().cpu().tolist())

    avg_loss = total_loss / max(nsteps, 1)
    auc = roc_auc_score(all_labels, all_probs) if all_labels else None
    return avg_loss, auc, (np.array(all_probs), np.array(all_labels), np.array(all_logits))


def _make_ordered_loader(df, is_test):
    """Build a non-shuffled DataLoader over ``df`` (predictions stay in row order)."""
    ds = RuleTripleDataset(df, tokenizer, max_len=CFG.max_len, is_test=is_test, seed=CFG.seed)
    return DataLoader(
        ds,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=CFG.num_workers,
        pin_memory=str(CFG.device).startswith("cuda"),
        persistent_workers=(CFG.num_workers > 0) and (not is_test),
        drop_last=False,
    )


def _predict_logits(model, loader):
    """Return ``(logits, labels)`` NumPy arrays in loader order; labels is empty without a "label" key."""
    model.eval()
    logits_out, labels_out = [], []
    with torch.no_grad():
        for batch in loader:
            batch = {k: (v.to(CFG.device) if torch.is_tensor(v) else v) for k, v in batch.items()}
            logits, _, _, _ = model(
                anchor_input_ids=batch["anchor_input_ids"],
                anchor_attention_mask=batch["anchor_attention_mask"],
                pos_input_ids=batch["pos_input_ids"],
                pos_attention_mask=batch["pos_attention_mask"],
                pos_mask=batch["pos_mask"],
                neg_input_ids=batch["neg_input_ids"],
                neg_attention_mask=batch["neg_attention_mask"],
                neg_mask=batch["neg_mask"],
            )
            logits_out.extend(logits.float().cpu().tolist())
            if "label" in batch:
                labels_out.extend(batch["label"].cpu().tolist())
    return np.array(logits_out), np.array(labels_out)


def train_hybrid(df_train, df_val, df_test=None):
    """Train one DualEncoder, keep the best epoch by validation AUC, and predict.

    Uses AdamW with a linear warm-up/decay schedule over ``CFG.epochs``.
    After training, the weights from the epoch with the highest validation
    AUC are restored before predicting.

    Returns:
        ``(model, val_probs, val_labels, val_logits, test_probs, test_logits)``.
        Validation arrays are in ``df_val`` row order (the caller scatters them
        into OOF positions by index); test arrays are None when ``df_test``
        is None.
    """
    model = DualEncoder(
        model_name=CFG.model_name,
        proj_dim=CFG.proj_dim,
        dropout=CFG.dropout
    ).to(CFG.device)

    # Optimizer & Scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
    steps_per_epoch = math.ceil(len(df_train) / CFG.batch_size)
    num_train_steps = steps_per_epoch * CFG.epochs
    num_warmup_steps = int(num_train_steps * CFG.warmup_ratio)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps
    )

    scaler = torch.cuda.amp.GradScaler(enabled=CFG.use_amp)

    # Loaders. Validation must not be shuffled: its predictions are returned in
    # df_val order and written back to OOF slots by row index.
    train_loader = make_loader(df_train, is_test=False)
    val_loader = _make_ordered_loader(df_val, is_test=False)
    test_loader = _make_ordered_loader(df_test, is_test=True) if df_test is not None else None

    best_auc, best_state = -1.0, None

    for epoch in range(1, CFG.epochs + 1):
        tr_loss, tr_auc, _ = run_one_epoch(model, train_loader, optimizer=optimizer, scaler=scaler, scheduler=scheduler)
        val_loss, val_auc, _ = run_one_epoch(model, val_loader, optimizer=None, scaler=None, scheduler=None)
        tr_auc_str = f" | TrainAUC {tr_auc:.4f}" if tr_auc is not None else ""
        val_auc_str = f"{val_auc:.4f}" if val_auc is not None else "n/a"
        print(f"Epoch {epoch:02d} | TrainLoss {tr_loss:.4f}{tr_auc_str}"
              f" | ValLoss {val_loss:.4f} | ValAUC {val_auc_str}")

        if val_auc is not None and val_auc > best_auc:
            best_auc = val_auc
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
    print(f"Best Val AUC: {best_auc:.4f}" if best_auc >= 0 else "No validation AUC computed.")

    # --- collect VAL predictions (logits + probs) for OOF, in df_val order ---
    val_logits, val_labels = _predict_logits(model, val_loader)
    val_probs = 1.0 / (1.0 + np.exp(-val_logits))  # sigmoid

    # --- TEST predictions (if provided) ---
    test_logits = test_probs = None
    if test_loader is not None:
        test_logits, _ = _predict_logits(model, test_loader)
        test_probs = 1.0 / (1.0 + np.exp(-test_logits))  # sigmoid

    return model, val_probs, val_labels, val_logits, test_probs, test_logits


def _logit(p, eps=1e-6):
    p = np.clip(p, eps, 1 - eps)
    return np.log(p / (1 - p))

def _sigmoid(x):  # numpy
    return 1.0 / (1.0 + np.exp(-x))

def kfold_train(df_trn, df_tst, n_splits=5, group_col="subreddit", seeds=(42, 1337, 2029)):
    """Multi-seed K-fold training with logit ensembling.

    For every seed, ``df_trn`` is split into ``n_splits`` folds with
    StratifiedGroupKFold on ``group_col`` when that column exists and has at
    least ``n_splits`` distinct values, otherwise with StratifiedKFold. Each
    fold trains one model via ``train_hybrid``.

    Ensembling:
      * out-of-fold logits are averaged uniformly across seeds;
      * test logits are averaged across folds weighted by fold AUC, then
        uniformly across seeds.

    Args:
        df_trn: training frame; ``rule_violation`` is binarised at 0.5 for
            stratification and AUC.
        df_tst: test frame, or None to skip test prediction.
        n_splits: folds per seed.
        group_col: column used for grouped splitting when possible.
        seeds: one full K-fold run per seed.

    Returns:
        ``(oof_probs, test_probs or None, per-seed lists of fold AUCs, cv_auc)``.

    Raises:
        RuntimeError: if train_hybrid returns a different number of validation
            predictions than there are validation rows.
    """
    import gc

    y = (df_trn["rule_violation"] > 0.5).astype(int).values
    use_group = group_col in df_trn.columns and df_trn[group_col].nunique() >= n_splits

    oof_logits_all = np.zeros(len(df_trn), dtype=float)
    test_logits_seeds = []   # one ensembled test-logit vector per seed
    fold_auc_seeds = []

    for seed in seeds:
        set_seed(seed)

        if use_group:
            groups = df_trn[group_col].astype(str).values
            cv = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
            fold_iter = cv.split(df_trn, y, groups)
            print(f"\n[seed {seed}] Using StratifiedGroupKFold by '{group_col}' ({df_trn[group_col].nunique()} groups).")
        else:
            cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
            fold_iter = cv.split(df_trn, y)
            print(f"\n[seed {seed}] Using StratifiedKFold.")

        oof_logits = np.zeros(len(df_trn), dtype=float)
        fold_aucs, test_fold_logits = [], []

        for fold, (tr_idx, va_idx) in enumerate(fold_iter):
            print(f"=== Seed {seed} | Fold {fold+1}/{n_splits} ===")
            tr_df = df_trn.iloc[tr_idx].reset_index(drop=True)
            va_df = df_trn.iloc[va_idx].reset_index(drop=True)

            # fresh seed per fold for dataloaders/shuffling
            set_seed(seed + fold)

            model, val_probs, val_labels, val_logits, _fold_test_probs, test_logits = train_hybrid(tr_df, va_df, df_tst)
            # The fitted model is not used here; release it now so its weights
            # are not still resident while the next fold builds a new model.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # OOF logits are scattered back by position, so train_hybrid must
            # return exactly one prediction per validation row, in va_df order.
            if len(val_logits) != len(va_idx):
                raise RuntimeError(
                    f"train_hybrid returned {len(val_logits)} val logits for {len(va_idx)} rows"
                )

            fold_auc = roc_auc_score(val_labels, val_probs)
            print(f"[seed {seed} fold {fold}] Val AUC: {fold_auc:.6f}")

            oof_logits[va_idx] = val_logits
            fold_aucs.append(fold_auc)
            if test_logits is not None:
                test_fold_logits.append(test_logits)

        # fold-AUC-weighted logit ensembling for this seed
        # (np.average normalises the weights itself)
        if test_fold_logits:
            weights = np.asarray(fold_aucs, dtype=float)
            test_logits_seeds.append(
                np.average(np.vstack(test_fold_logits), axis=0, weights=weights)
            )

        oof_logits_all += oof_logits
        fold_auc_seeds.append(fold_aucs)

        cv_auc_seed = roc_auc_score(y, _sigmoid(oof_logits))
        print(f"[seed {seed}] CV AUC: {cv_auc_seed:.6f} | per-fold: {[round(a,6) for a in fold_aucs]}")

    # average OOF logits across seeds
    oof_logits_all /= max(len(seeds), 1)
    cv_auc = roc_auc_score(y, _sigmoid(oof_logits_all))
    print(f"\nOverall CV AUC (seeds {len(seeds)} × folds {n_splits}): {cv_auc:.6f}")

    # average test logits across seeds (simple mean)
    pred_test = None
    if test_logits_seeds:
        pred_test = _sigmoid(np.mean(np.vstack(test_logits_seeds), axis=0))

    return _sigmoid(oof_logits_all), pred_test, fold_auc_seeds, cv_auc


# Run the full multi-seed K-fold pipeline: OOF probabilities for df_trn,
# ensembled test probabilities for df_tst, per-seed fold AUCs, overall CV AUC.
oof_probs, test_probs, fold_auc_seeds, cv_auc = kfold_train(
    df_trn=df_trn,
    df_tst=df_tst,
    n_splits=5,           # kfold_train falls back to StratifiedKFold if group_col has fewer than n_splits groups
    group_col="subreddit",
    seeds=(42, 1337, 2029)
)

print("Final CV AUC:", cv_auc)
# Write a submission only when test predictions were produced.
if test_probs is not None:
    sub = pd.DataFrame({"row_id": df_tst["row_id"], "rule_violation": test_probs})
    sub.to_csv("submission_kfold.csv", index=False)
    print("Saved submission_kfold.csv")

# Colab only: mount Google Drive (needed for the /content/drive data path).
# from google.colab import drive
# drive.mount('/content/drive')