In [3]:
# Clone the diffusion model into our colab environment
!git clone https://github.com/lucidrains/denoising-diffusion-pytorch.git
%cd denoising-diffusion-pytorch/
!pip install -e .
!pip install pypng

print(f"Git Clone finished at {datetime.now()}")

Cloning into 'denoising-diffusion-pytorch'...
remote: Enumerating objects: 1884, done.[K
remote: Total 1884 (delta 0), reused 0 (delta 0), pack-reused 1884 (from 1)[K
Receiving objects: 100% (1884/1884), 2.57 MiB | 30.93 MiB/s, done.
Resolving deltas: 100% (1336/1336), done.
/content/denoising-diffusion-pytorch/denoising-diffusion-pytorch
Obtaining file:///content/denoising-diffusion-pytorch/denoising-diffusion-pytorch
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: denoising-diffusion-pytorch
  Attempting uninstall: denoising-diffusion-pytorch
    Found existing installation: denoising-diffusion-pytorch 2.1.1
    Uninstalling denoising-diffusion-pytorch-2.1.1:
      Successfully uninstalled denoising-diffusion-pytorch-2.1.1
  Running setup.py develop for denoising-diffusion-pytorch
Successfully installed denoising-diffusion-pytorch-2.1.1
Collecting pypng
  Downloading pypng-0.20220715.0-py3-none-any.whl.metadata (13 kB)
Downloading pypng-0.2022071

In [4]:
# Get the imports; check if the gpu is available
import time
from typing import List, Dict
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import os
import numpy as np
from PIL import Image
from datetime import datetime
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer
import matplotlib.pyplot as plt
import pandas as pd
import io
import re
import struct
import png
from google.colab import files

# Probe CUDA once, then pick the matching torch device
gpuAvailable = torch.cuda.is_available()
device = torch.device('cuda') if gpuAvailable else torch.device('cpu')
print(f'Can I can use GPU now? -- {gpuAvailable}')

print(f"Imports finished at {datetime.now()}")

Can I can use GPU now? -- False
Imports finished at 2025-04-12 15:04:14.421964


In [5]:
# NOTE: This model was trained using the ROCS dataset. To recreate the work,
# you will need to provide the same file I used.

# files.upload() opens Colab's file picker; its result is kept in `uploaded`
# and indexed by file name when the CSV is parsed
print("This is a manual interaction code section. Click the 'Choose Files' button, and please select the 'ROCS_winter2018.csv' file")
uploaded = files.upload()

This is a manual interaction code section. Click the 'Choose Files' button, and please select the 'ROCS_winter2018.csv' file


Saving ROCS_winter2018.csv to ROCS_winter2018.csv


In [15]:
# This cell embeds all of the example sentences into PNG images, as discussed in the paper

# Parse the uploaded CSV bytes into a DataFrame
df = pd.read_csv(io.BytesIO(uploaded['ROCS_winter2018.csv']))

print(f"Absobed input data. The shape of the data is {df.shape}")

# Pull out the four input-sentence columns
sent1, sent2, sent3, sent4 = (
    df[column]
    for column in ('InputSentence1', 'InputSentence2', 'InputSentence3', 'InputSentence4')
)

numSamples = len(sent1)

# Special words that identify the start and end of sentences (explained in more detail later)
startWord, endWord = 'sss', 'eee'

# Splits a sentence into its lowercase words.
# Apostrophes inside a word are kept; all other punctuation is dropped.
def tokenize(text: str) -> List[str]:
    """Return the lowercase words found in ``text``.

    A word is a run of ASCII letters, optionally followed by one apostrophe
    and another run of letters (e.g. "isn't"), delimited by word boundaries.
    """
    pattern = re.compile(r"\b[A-Za-z]+(?:'[A-Za-z]+)?\b")
    return [match.group(0).lower() for match in pattern.finditer(text)]

# DEBUG ONLY
# example = "He said, 'Isn't O'Brian the best?'"
# print(example.split())
# print(tokenize(example))

# Build the word bank (word -> occurrence count); there's only
# about 8500 words in the train data file that I've chosen
wordBank = {}

# wordIndex (word -> number) is what we will use to look up words in the markov matrix;
# reverseIndex is the opposite mapping (number -> word)
wordIndex = {}
reverseIndex = {}
indexIterator = 0

# Read through all the sentences
for i in range(numSamples):
  # Tokenize the four sentences of sample i before touching the counters
  tokenizedRow = [tokenize(column[i]) for column in (sent1, sent2, sent3, sent4)]

  for sentence in tokenizedRow:
    for word in sentence:
      # Count every word, we're going to use the most frequent words
      if word in wordBank:
        wordBank[word] += 1
        continue

      # First sighting: start the count and hand out the next free index
      wordBank[word] = 1
      wordIndex[word] = indexIterator
      reverseIndex[indexIterator] = word
      indexIterator += 1

# Add the start word and the end word.
# They are registered in reverseIndex as well: the decoder only checks an index
# against len(wordBank), so an index without a reverseIndex entry would raise KeyError.
for markerWord in (startWord, endWord):
  wordBank[markerWord] = 1
  wordIndex[markerWord] = indexIterator
  reverseIndex[indexIterator] = markerWord
  indexIterator += 1

# For testing
#print(wordBank)
#print(f"The length of the wordbank is {len(wordBank)}")
#print(f"The length of the word index is {indexIterator}")

# Folder that receives one PNG per encoded sentence group
output_folder = './sentenceImages/'
# exist_ok makes the cell safe to re-run without a separate existence check
os.makedirs(output_folder, exist_ok=True)


# When provided a list of tokenized words, this method will save them to a png file
def convertSentenceToBytes(tokenizedSentence: List[str], fileNumber) -> None:
  """Encode a tokenized sentence group as a 12x12 grayscale PNG.

  Every word is looked up in the module-level ``wordIndex`` and written as two
  pixels (high byte first, then low byte), filling the image row by row.
  Pixels that are not needed stay 0.

  Args:
    tokenizedSentence: Words to encode; each must be a key of ``wordIndex``.
      Only the first 72 words fit in the image; any further words are dropped.
    fileNumber: Used to build the output name
      ``sentenceImages/sentenceGroup<fileNumber>.png``.

  Raises:
    KeyError: If a word is missing from ``wordIndex``.
    ValueError: If a word's index does not fit in two bytes.
  """
  # The dimensions of the image that we're going to make
  height = 12
  width = 12
  bytesPerWord = 2
  maxWords = (height * width) // bytesPerWord

  byteArray = np.zeros(height * width, dtype=np.uint8)

  # Words beyond the image capacity are dropped instead of overrunning the buffer
  for position, word in enumerate(tokenizedSentence[:maxWords]):
    wordNumber = wordIndex[word]
    if not 0 <= wordNumber <= 0xFFFF:
      raise ValueError(f"Index {wordNumber} of word {word!r} does not fit in 2 bytes")

    # Big-endian split: high byte first, low byte second
    byteArray[bytesPerWord * position] = wordNumber >> 8
    byteArray[bytesPerWord * position + 1] = wordNumber & 0xFF

  # Image.fromarray turns a 1-D array into a 1-pixel-wide image, so the flat
  # buffer is reshaped to (height, width) first; a 2-D uint8 array is read as
  # a grayscale ('L') image.
  img = Image.fromarray(byteArray.reshape(height, width))

  filePath = "sentenceImages/sentenceGroup" + str(fileNumber) + ".png"
  img.save(filePath)


# Loop over all sentences and generate images
for i in range(numSamples):
  # Tokenize the four sentences of sample i, then join their words into one flat list
  tokenizedRow = [tokenize(column[i]) for column in (sent1, sent2, sent3, sent4)]
  allSentences = [word for sentenceWords in tokenizedRow for word in sentenceWords]

  # One image per sample, numbered by the sample's row
  convertSentenceToBytes(allSentences, i)

print(f"Text to images completed at: {datetime.now()}")

Absobed input data. The shape of the data is (3142, 8)
Text to images completed at: 2025-04-12 15:13:48.260691


In [23]:
# Use this cell to train

# Number of images per training batch
batch_size = 32

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4),
    channels = 1
)

diffusion = GaussianDiffusion(
    model,
    image_size = 12,
    timesteps = 1000    # number of steps; decrease this to speedup, quality will suffer
)

# Define the trainer; output_folder is the folder holding the encoded sentence PNGs

trainer = Trainer(
    diffusion,
    output_folder,
    train_batch_size = batch_size,
    train_lr = 2e-4,
    train_num_steps = 100000,  # Total training steps; decrease this to speedup, quality will suffer
    gradient_accumulate_every = 2,
    ema_decay = 0.995,
    # The pixel order of each image IS the word order of the sentence, so the
    # Trainer's default random horizontal flip would feed scrambled sentences.
    augment_horizontal_flip = False,
    # fp16 mixed precision needs a GPU; train in full precision otherwise
    amp = torch.cuda.is_available(),
    calculate_fid = False
)

trainer.train()

print(datetime.now())

  0%|          | 0/100 [00:00<?, ?it/s]

training complete
2025-04-12 16:55:00.741047


In [14]:
import shutil
# Delete the folder of encoded sentence images, or report that there is nothing to delete
if not os.path.exists(output_folder):
    print(f"The directory: {output_folder} does not exist")
else:
    shutil.rmtree(output_folder)

The directory: ./sentenceImages/ does not exist


In [24]:
# Use this cell to generate

samples = diffusion.sample(batch_size = 1)  # expected shape: (1, 1, 12, 12)

genFolder = "./Generated/"
os.makedirs(genFolder, exist_ok=True)

# Bring the batch to host memory first: Tensor.numpy() raises on CUDA tensors
samplesArray = samples.detach().cpu().numpy()

for i in range(samplesArray.shape[0]):
    # Scale to byte range. Round to the nearest integer (a plain uint8 cast
    # truncates) and clip so out-of-range values cannot wrap around.
    # Assumes sample values are nominally in [0, 1] - TODO confirm.
    img_array = np.clip(np.rint(samplesArray[i] * 255), 0, 255).astype(np.uint8)

    # Channel 0 is used as the image; a 2-D uint8 array becomes a grayscale ('L') image
    img = Image.fromarray(img_array[0])

    # One file per sample, so a larger batch does not overwrite image_0.png
    file_path = os.path.join(genFolder, f'image_{i}.png')
    img.save(file_path)
    print(f'Saved {file_path}')


print(f"Image generated: {datetime.now()}")


sampling loop time step:   0%|          | 0/5 [00:00<?, ?it/s]

Saved ./Generated/image_0.png
Image generated: 2025-04-12 17:00:43.496299


In [25]:
# Use this cell to decode the images

# The file to read
filePath = "./Generated/image_0.png"

#############################################################################################
# Open the image and read the sentence
#############################################################################################
# Read the pixels as one flat byte string, one byte per grayscale pixel
with Image.open(filePath) as image:
    testBytes = image.convert('L').tobytes()

# Words decoded so far, in reading order
decodedWords = []
count = 0
foundFirstZero = False  # True only while the previous 2-byte value was a zero

# Every word occupies two consecutive bytes, high byte first.
# A trailing unpaired byte (odd length) is ignored.
for position in range(0, len(testBytes) - 1, 2):
    completeByte = (testBytes[position] << 8) | testBytes[position + 1]

    # The exit condition is if we find two zeros in a row (once a first word
    # has been read); a single zero is skipped.
    if decodedWords and completeByte == 0:
        if foundFirstZero:
            break
        foundFirstZero = True
        continue

    # A non-zero value ends any run of zeros
    foundFirstZero = False

    # Any index without an entry in reverseIndex is reported as junk instead of raising KeyError
    decodedWords.append(reverseIndex.get(completeByte, "JUNK_WORD_OUT_OF_CONTEXT"))
    count += 1

# Every word is preceded by a single space
wordString = "".join(" " + word for word in decodedWords)

print(wordString)


print(datetime.now())


 JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT JUNK_WORD_OUT_OF_CONTEXT nation ruth gloves nacho