In [17]:
import os
import requests
import tarfile
import shutil
import random

PROJECT_ID = "text Classification" # Set this to your project name
BUCKET_URI = f"gs://{PROJECT_ID}-imdb-email-dataset"
REGION = "us-central1"

if not PROJECT_ID:
    raise ValueError("You must set a non-empty PROJECT_ID and make sure the project is created in GCP") 
    
os.environ["PROJECT_ID"] = PROJECT_ID
os.environ["REGION"] = REGION
os.environ["BUCKET_URI"] = BUCKET_URI

In [None]:
url = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"

def maybe_download(url:str, download_dir:str="data"):
    """ Download the dataset from the web and extract the data """
    _, filename = os.path.split(url)
    download_path = os.path.join(download_dir, filename)

    if not os.path.exists(download_path):
        print("Downloading data")
        response = requests.get(url)
        os.makedirs(download_dir, exist_ok=True)

        with open(download_path, "wb") as f:
            f.write(response.content)
    else:
        print("File is already downloaded")

    tar = tarfile.open(download_path, "r:gz")
    tar.extractall(os.path.join(download_dir,"unzip"))
    tar.close()

maybe_download(url)

In [None]:
def generate_truncated_dataset(dataset_dir:str, n:int=2000):
    """It may take painstakenly long time to create a vertex AI dataset.
    Therefore we're going to limit the data we'll be using"""

    n_per_label = n//2
    parent_dir, _ = os.path.split(dataset_dir)
    truncated_dir = os.path.join(parent_dir, "unzip_truncated")
    os.makedirs(truncated_dir, exist_ok=True)

    for sub_dir in [
        os.path.join("aclImdb", "train", "neg"), 
        os.path.join("aclImdb", "train", "pos"), 
        os.path.join("aclImdb", "test", "neg"), 
        os.path.join("aclImdb", "test", "pos")
    ]:

        os.makedirs(os.path.join(truncated_dir, sub_dir), exist_ok=True)
        for f in os.listdir(os.path.join(dataset_dir, sub_dir)):
            sample_id = int(f.split("_")[0])
            if sample_id < n_per_label:
                shutil.copy(os.path.join(dataset_dir, sub_dir, f), os.path.join(truncated_dir, sub_dir, f))

generate_truncated_dataset(os.path.join("data/unzip"))

In [18]:
import pydantic 
from typing import Dict,List, Tuple
from typing_extensions import Literal

# Format from https://cloud.google.com/vertex-ai/docs/text-data/classification/prepare-data
class ClassificationAnnotation(pydantic.BaseModel):
    displayName: Literal["positive", "negative"]

class DataItemResourceLabels(pydantic.BaseModel):
    ml_use: Literal["training", "validation", "test"] = pydantic.Field(alias="aiplatform.googleapis.com/ml_use")
    # Enables us to use ml_use=<x> instead of the long field name
    class Config:
        allow_population_by_field_name = True

class TextClassificationSample(pydantic.BaseModel):
    textContent: str
    classificationAnnotation: ClassificationAnnotation 
    dataItemResourceLabels: DataItemResourceLabels

instance = TextClassificationSample(
    textContent="some review text", 
    classificationAnnotation=ClassificationAnnotation(displayName="positive"),
    dataItemResourceLabels=DataItemResourceLabels(ml_use="training")
)

print(instance.json(by_alias=True, indent=4))

{
    "textContent": "some review text",
    "classificationAnnotation": {
        "displayName": "positive"
    },
    "dataItemResourceLabels": {
        "aiplatform.googleapis.com/ml_use": "training"
    }
}


In [19]:
from google.cloud import storage 
random.seed(946021)


def read_data(file_path:str) -> str:
    """ Read a text file from a given path """
    with open(file_path, "r") as f:
        data = f.read()
    
    # Solving the non-interchangeable valid content error during data import
    data_processed = data.replace("\u0085", " ").replace("\u0096", " ")
    return data_processed

def generate_single_instance(file_path:str, ml_use:str) -> TextClassificationSample:
    """ Given a filepath, create a single TextClassificationSample instance """
    label = None
    if "pos" in file_path:
        label = "positive"
    elif "neg" in file_path:
        label = "negative"
    if label:
        instance = TextClassificationSample(
            textContent=read_data(file_path), 
            classificationAnnotation=ClassificationAnnotation(displayName=label),
            dataItemResourceLabels=DataItemResourceLabels(ml_use=ml_use)
        )
        return instance
    else:
        raise ValueError("label cannot be None")
        
def create_instances(data_dir: str) -> Tuple[List[TextClassificationSample], List[TextClassificationSample]]:

    train_subdir = "train"
    test_subdir = "test"
    train_instances = []

    print(f"Reading training data from the GCS bucket")
    for root, _, files in os.walk(os.path.join(data_dir, train_subdir)):
        for fname in files:
            fpath = os.path.join(root, fname)
            if fpath.endswith(".txt"):
                instance = generate_single_instance(
                    file_path=fpath, ml_use="training"
                )
                train_instances.append(instance.json(by_alias=True, ensure_ascii=False)+'\n')

    print(f"\tFound {len(train_instances)} train instances")

    test_instances = []
    valid_count, test_count = 0,0

    print(f"Reading test data from the GCS bucket")
    for root, _, files in os.walk(os.path.join(data_dir, test_subdir)):
        for fname in files:
            fpath = os.path.join(root, fname)
            if fpath.endswith(".txt"):
                if random.uniform(0,1.0)<0.5:
                    valid_count += 1
                    ml_use="validation"
                else:
                    test_count += 1
                    ml_use="test"
                
                data = read_data(fpath)

                instance = generate_single_instance(
                    file_path=fpath, ml_use=ml_use
                )
                test_instances.append(instance.json(by_alias=True, ensure_ascii=False)+'\n')

    print(f"\tFound {valid_count} validation instances and {test_count} test instances")

    return train_instances, test_instances

train_instances, test_instances = create_instances(os.path.join("data", "unzip_truncated", "aclImdb"))

Reading training data from the GCS bucket
	Found 2001 train instances
Reading test data from the GCS bucket
	Found 1013 validation instances and 987 test instances


In [20]:
import json

with open(os.path.join("data", "train_instances.jsonl"), "w") as f:
    f.writelines(train_instances)

with open(os.path.join("data", "test_instances.jsonl"), "w") as f:
    f.writelines(test_instances)

In [21]:
!gcloud config set project $PROJECT_ID

Updated property [core/project].


In [22]:
%%bash

if gsutil ls $BUCKET_URI; then
    echo "Bucket ${BUCKET_URI} already exists.";
else
    echo "Bucket ${BUCKET_URI} doesn't exist. Creating a new one"
    gsutil mb -l $REGION -p $PROJECT_ID $BUCKET_URI
fi


gs://cust-eng-tac-tools-dev-imdb-email-dataset/test_instances.jsonl
gs://cust-eng-tac-tools-dev-imdb-email-dataset/train_instances.jsonl
Bucket gs://cust-eng-tac-tools-dev-imdb-email-dataset already exists.


In [15]:
%%bash

if gsutil ls $BUCKET_URI; then
    echo "Bucket ${BUCKET_URI} already exists.";
else
    echo "Bucket ${BUCKET_URI} doesn't exist. Creating a new one"
    gsutil mb -l $REGION -p $PROJECT_ID $BUCKET_URI
fi


Bucket gs://cust-eng-tac-tools-dev-imdb-email-dataset doesn't exist. Creating a new one


BucketNotFoundException: 404 gs://cust-eng-tac-tools-dev-imdb-email-dataset bucket does not exist.
Creating gs://cust-eng-tac-tools-dev-imdb-email-dataset/...


In [23]:
!gsutil cp data/train_instances.jsonl $BUCKET_URI
!gsutil cp data/test_instances.jsonl $BUCKET_URI

Copying file://data/train_instances.jsonl [Content-Type=application/octet-stream]...
/ [1 files][  2.8 MiB/  2.8 MiB]                                                
Operation completed over 1 objects/2.8 MiB.                                      
Copying file://data/test_instances.jsonl [Content-Type=application/octet-stream]...
- [1 files][  2.8 MiB/  2.8 MiB]                                                
Operation completed over 1 objects/2.8 MiB.                                      
